Comment fonctionne tf.app.run ()?

150

Comment tf.app.run()fonctionne la démo de Tensorflow?

Dans tensorflow/models/rnn/translate/translate.py, il y a un appel à tf.app.run(). Comment cela est-il géré?

if __name__ == "__main__":
    tf.app.run() 
Anurag Ranjan
la source

Réponses:

135
if __name__ == "__main__":

signifie que le fichier courant est exécuté sous un shell au lieu d'être importé en tant que module.

tf.app.run()

Comme vous pouvez le voir à travers le fichier app.py

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

Coupons ligne par ligne:

flags_passthrough = f._parse_flags(args=args)

Cela garantit que l'argument que vous passez à travers la ligne de commande est valide, par exemple, en python my_model.py --data_dir='...' --max_iteration=10000fait, cette fonctionnalité est implémentée sur la base du argparsemodule standard python .

main = main or sys.modules['__main__'].main

Le premier mainà droite de =est le premier argument de la fonction courante run(main=None, argv=None) . Alors que sys.modules['__main__']signifie le fichier en cours d'exécution (par exemple my_model.py).

Il y a donc deux cas:

  1. Vous n'avez pas de mainfonction dans my_model.pyEnsuite, vous devez appelertf.app.run(my_main_running_function)

  2. vous avez une mainfonction dans my_model.py. (C'est surtout le cas.)

Dernière ligne:

sys.exit(main(sys.argv[:1] + flags_passthrough))

garantit que votre fonction main(argv)ou my_main_running_function(argv)est appelée correctement avec des arguments analysés.

lei du
la source
67
Une pièce manquante du puzzle pour les utilisateurs débutants de Tensorflow: Tensorflow a un mécanisme intégré de gestion des indicateurs de ligne de commande. Vous pouvez définir vos indicateurs comme tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.'), puis si vous l'utilisez, les tf.app.run()choses seront configurées pour que vous puissiez accéder globalement aux valeurs passées des indicateurs que vous avez définis, comme tf.flags.FLAGS.batch_sizepartout où vous en avez besoin dans votre code.
isarandi
1
C'est la meilleure réponse des trois (actuels) à mon avis. Il explique le "Comment fonctionne tf.app.run ()", tandis que les deux autres réponses disent simplement ce qu'il fait.
Thomas Fauskanger
On dirait que les drapeaux sont gérés par abseillesquels TF doit avoir absorbé abseil.io/docs/python/guides/flags
CpILL
75

C'est juste un wrapper très rapide qui gère l'analyse des indicateurs, puis les distribue à votre propre main. Voir le code .

dga
la source
12
que signifie "gère l'analyse des indicateurs"? Peut-être pourriez-vous ajouter un lien pour informer les débutants de ce que cela signifie?
Pinocchio du
4
Il analyse les arguments de ligne de commande fournis au programme à l'aide du package flags. (qui utilise la bibliothèque standard 'argparse' sous les couvertures, avec quelques wrappers). Il est lié au code auquel j'ai lié dans ma réponse.
dga
1
Dans app.py, que signifient main = main or sys.modules['__main__'].mainet que sys.exit(main(sys.argv[:1] + flags_passthrough))signifient?
hAcKnRoCk
3
cela me semble étrange, pourquoi envelopper la fonction principale dans tout cela si vous pouvez simplement l'appeler directement main()?
Charlie Parker
2
hAcKnRoCk: s'il n'y a pas de main dans le fichier, il utilise à la place ce qui se trouve dans sys.modules [' main '] .main. Le sys.exit signifie exécuter la commande principale ainsi trouvée en utilisant les arguments et tous les indicateurs passés, et quitter avec la valeur de retour de main. @CharlieParker - pour la compatibilité avec les bibliothèques d'applications python existantes de Google telles que gflags et google-apputils. Voir, par exemple, github.com/google/google-apputils
dga
8

Il n'y a rien de spécial tf.app. Ceci est juste un script de point d'entrée générique , qui

Exécute le programme avec une fonction optionnelle «main» et une liste «argv».

Cela n'a rien à voir avec les réseaux de neurones et il appelle simplement la fonction principale, en lui passant des arguments.

Salvador Dali
la source
5

En termes simples, le travail de tf.app.run()consiste à définir d' abord les indicateurs globaux pour une utilisation ultérieure, comme:

from tensorflow.python.platform import flags
f = flags.FLAGS

puis exécutez votre fonction principale personnalisée avec un ensemble d'arguments.

Par exemple, dans la base de code TensorFlow NMT , le tout premier point d'entrée pour l'exécution du programme pour l'apprentissage / l'inférence commence à ce stade (voir le code ci-dessous)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Après avoir analysé les arguments en utilisant argparse, avec tf.app.run()vous exécutez la fonction "main" qui est définie comme:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

Ainsi, après avoir défini les indicateurs pour une utilisation globale, tf.app.run()lance simplement la mainfonction que vous lui passez avec argvcomme paramètres.

PS: Comme le dit la réponse de Salvador Dali , c'est juste une bonne pratique d'ingénierie logicielle, je suppose, même si je ne suis pas sûr que TensorFlow exécute une exécution optimisée de la mainfonction que celle qui a été exécutée en utilisant CPython normal.

kmario23
la source
2

Le code de Google dépend beaucoup de l'accès aux indicateurs globaux dans les bibliothèques / binaires / scripts python et donc tf.app.run () analyse ces indicateurs pour créer un état global dans la variable FLAGs (ou quelque chose de similaire) puis appelle python main ( ) Comme il se doit.

S'ils n'avaient pas cet appel à tf.app.run (), alors les utilisateurs pourraient oublier de faire l'analyse des FLAG, ce qui empêcherait ces bibliothèques / binaires / scripts d'avoir accès aux FLAG dont ils ont besoin.

Mudit Jain
la source
1

Réponse compatible 2.0 : Si vous souhaitez utiliser tf.app.run()dans Tensorflow 2.0, nous devons utiliser la commande,

tf.compat.v1.app.run()ou vous pouvez utiliser tf_upgrade_v2pour convertir le 1.xcode en 2.0.

Prise en charge de Tensorflow
la source