TensorFlow sauvegarde / chargement d'un graphique à partir d'un fichier

98

D'après ce que j'ai rassemblé jusqu'à présent, il existe plusieurs façons de vider un graphique TensorFlow dans un fichier, puis de le charger dans un autre programme, mais je n'ai pas été en mesure de trouver des exemples / informations clairs sur leur fonctionnement. Ce que je sais déjà, c'est ceci:

  1. Enregistrez les variables du modèle dans un fichier de point de contrôle (.ckpt) en utilisant a tf.train.Saver()et restaurez-les plus tard ( source )
  2. Enregistrer un modèle dans un fichier .pb et le recharger en utilisant tf.train.write_graph()et tf.import_graph_def()( source )
  3. Charger un modèle à partir d'un fichier .pb, le recycler et le vider dans un nouveau fichier .pb à l'aide de Bazel ( source )
  4. Figer le graphique pour enregistrer le graphique et les poids ensemble ( source )
  5. Utilisez as_graph_def()pour enregistrer le modèle, et pour les poids / variables, les mapper en constantes ( source )

Cependant, je n'ai pas pu clarifier plusieurs questions concernant ces différentes méthodes:

  1. En ce qui concerne les fichiers de point de contrôle, enregistrent-ils uniquement les poids entraînés d'un modèle? Les fichiers de point de contrôle peuvent-ils être chargés dans un nouveau programme et être utilisés pour exécuter le modèle, ou servent-ils simplement à enregistrer les poids dans un modèle à un certain moment / étape?
  2. Concernant tf.train.write_graph(), les poids / variables sont-ils également enregistrés?
  3. En ce qui concerne Bazel, peut-il uniquement enregistrer dans / charger des fichiers .pb pour le recyclage? Existe-t-il une simple commande Bazel juste pour vider un graphique dans un .pb?
  4. En ce qui concerne le gel, un graphique figé peut-il être chargé en utilisant tf.import_graph_def()?
  5. La démo Android pour TensorFlow se charge dans le modèle Inception de Google à partir d'un fichier .pb. Si je voulais remplacer mon propre fichier .pb, comment procéderais-je? Aurais-je besoin de changer de code / méthodes natifs?
  6. En général, quelle est exactement la différence entre toutes ces méthodes? Ou plus largement, quelle est la différence entre as_graph_def()/.ckpt/.pb?

En bref, ce que je recherche, c'est une méthode pour enregistrer à la fois un graphique (comme dans, les différentes opérations et autres) et ses poids / variables dans un fichier, qui peut ensuite être utilisé pour charger le graphique et les poids dans un autre programme , à utiliser (pas nécessairement de poursuite / de recyclage).

La documentation sur ce sujet n'est pas très simple, donc toute réponse / information serait grandement appréciée.

Technicolor
la source
2
L'API la plus récente / la plus complète est le méta-graphe, qui vous donnera un moyen de sauvegarder les trois à la fois - 1) graphe 2) valeurs de paramètres 3) collections: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Réponses:

80

Il existe de nombreuses façons d'aborder le problème de l'enregistrement d'un modèle dans TensorFlow, ce qui peut le rendre un peu déroutant. En prenant chacune de vos sous-questions à tour de rôle:

  1. Les fichiers de point de contrôle (produits par exemple en appelant saver.save()un tf.train.Saverobjet) ne contiennent que les poids et toutes les autres variables définies dans le même programme. Pour les utiliser dans un autre programme, vous devez recréer la structure de graphe associée (par exemple en exécutant du code pour le reconstruire, ou en appelant tf.import_graph_def()), qui indique à TensorFlow quoi faire avec ces pondérations. Notez que l'appel saver.save()produit également un fichier contenant a MetaGraphDef, qui contient un graphique et des détails sur la façon d'associer les poids d'un point de contrôle à ce graphique. Voir le tutoriel pour plus de détails.

  2. tf.train.write_graph()n'écrit que la structure du graphe; pas les poids.

  3. Bazel n'est pas lié à la lecture ou à l'écriture de graphiques TensorFlow. (Peut-être ai-je mal compris votre question: n'hésitez pas à la clarifier dans un commentaire.)

  4. Un graphique figé peut être chargé en utilisant tf.import_graph_def(). Dans ce cas, les pondérations sont (généralement) intégrées dans le graphique, vous n'avez donc pas besoin de charger un point de contrôle séparé.

  5. Le principal changement serait de mettre à jour les noms du (des) tenseur (s) qui sont introduits dans le modèle, et les noms du (des) tenseur (s) qui sont extraits du modèle. Dans la démo Android de TensorFlow, cela correspondrait aux chaînes inputNameet outputNametransmises à TensorFlowClassifier.initializeTensorFlow().

  6. Il GraphDefs'agit de la structure du programme, qui ne change généralement pas au cours du processus de formation. Le point de contrôle est un instantané de l'état d'un processus de formation, qui change généralement à chaque étape du processus de formation. En conséquence, TensorFlow utilise différents formats de stockage pour ces types de données, et l'API de bas niveau fournit différentes manières de les enregistrer et de les charger. Les bibliothèques de niveau supérieur, telles que les MetaGraphDefbibliothèques, Keras et skflow, s'appuient sur ces mécanismes pour fournir des moyens plus pratiques pour enregistrer et restaurer un modèle entier.

mrry
la source
Cela signifie-t-il que la documentation de l'API C ++ ment, quand elle dit que vous pouvez charger le graphique enregistré avec tf.train.write_graph()puis l'exécuter?
mnicky
2
La documentation de l'API C ++ ne ment pas, mais il manque quelques détails. Le détail le plus important est qu'en plus de celui GraphDefenregistré par tf.train.write_graph(), vous devez également vous souvenir des noms des tenseurs que vous souhaitez alimenter et récupérer lors de l'exécution du graphe (point 5 ci-dessus).
mrry
@mrry: J'ai essayé d'utiliser l'exemple de tensorflows DeepDream. mais il semble qu'il ait besoin de modèles pré-entraînés au format pb! J'ai exécuté l'exemple Cifar10, mais il ne crée que des points de contrôle! Je n'ai pas pu trouver de fichiers pb ou quoi que ce soit! comment puis-je convertir mes points de contrôle au format pb utilisé par l'exemple de deepdream?
Rika
2
@ Coderx7 Je pense vraiment que vous ne pouvez pas convertir un .ckpt en .pb puisque le point de contrôle ne contient que les poids et les variables et ne sait rien de la structure du graphique
davidivad
1
existe-t-il un code simple pour charger un fichier .pb et l'exécuter?
Kong
1

Vous pouvez essayer le code suivant:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Srihari Humbarwadi
la source