D'après ce que j'ai rassemblé jusqu'à présent, il existe plusieurs façons de vider un graphique TensorFlow dans un fichier, puis de le charger dans un autre programme, mais je n'ai pas été en mesure de trouver des exemples / informations clairs sur leur fonctionnement. Ce que je sais déjà, c'est ceci:
- Enregistrez les variables du modèle dans un fichier de point de contrôle (.ckpt) en utilisant a
tf.train.Saver()
et restaurez-les plus tard ( source ) - Enregistrer un modèle dans un fichier .pb et le recharger en utilisant
tf.train.write_graph()
ettf.import_graph_def()
( source ) - Charger un modèle à partir d'un fichier .pb, le recycler et le vider dans un nouveau fichier .pb à l'aide de Bazel ( source )
- Figer le graphique pour enregistrer le graphique et les poids ensemble ( source )
- Utilisez
as_graph_def()
pour enregistrer le modèle, et pour les poids / variables, les mapper en constantes ( source )
Cependant, je n'ai pas pu clarifier plusieurs questions concernant ces différentes méthodes:
- En ce qui concerne les fichiers de point de contrôle, enregistrent-ils uniquement les poids entraînés d'un modèle? Les fichiers de point de contrôle peuvent-ils être chargés dans un nouveau programme et être utilisés pour exécuter le modèle, ou servent-ils simplement à enregistrer les poids dans un modèle à un certain moment / étape?
- Concernant
tf.train.write_graph()
, les poids / variables sont-ils également enregistrés? - En ce qui concerne Bazel, peut-il uniquement enregistrer dans / charger des fichiers .pb pour le recyclage? Existe-t-il une simple commande Bazel juste pour vider un graphique dans un .pb?
- En ce qui concerne le gel, un graphique figé peut-il être chargé en utilisant
tf.import_graph_def()
? - La démo Android pour TensorFlow se charge dans le modèle Inception de Google à partir d'un fichier .pb. Si je voulais remplacer mon propre fichier .pb, comment procéderais-je? Aurais-je besoin de changer de code / méthodes natifs?
- En général, quelle est exactement la différence entre toutes ces méthodes? Ou plus largement, quelle est la différence entre
as_graph_def()
/.ckpt/.pb?
En bref, ce que je recherche, c'est une méthode pour enregistrer à la fois un graphique (comme dans, les différentes opérations et autres) et ses poids / variables dans un fichier, qui peut ensuite être utilisé pour charger le graphique et les poids dans un autre programme , à utiliser (pas nécessairement de poursuite / de recyclage).
La documentation sur ce sujet n'est pas très simple, donc toute réponse / information serait grandement appréciée.
la source
Réponses:
Il existe de nombreuses façons d'aborder le problème de l'enregistrement d'un modèle dans TensorFlow, ce qui peut le rendre un peu déroutant. En prenant chacune de vos sous-questions à tour de rôle:
Les fichiers de point de contrôle (produits par exemple en appelant
saver.save()
untf.train.Saver
objet) ne contiennent que les poids et toutes les autres variables définies dans le même programme. Pour les utiliser dans un autre programme, vous devez recréer la structure de graphe associée (par exemple en exécutant du code pour le reconstruire, ou en appelanttf.import_graph_def()
), qui indique à TensorFlow quoi faire avec ces pondérations. Notez que l'appelsaver.save()
produit également un fichier contenant aMetaGraphDef
, qui contient un graphique et des détails sur la façon d'associer les poids d'un point de contrôle à ce graphique. Voir le tutoriel pour plus de détails.tf.train.write_graph()
n'écrit que la structure du graphe; pas les poids.Bazel n'est pas lié à la lecture ou à l'écriture de graphiques TensorFlow. (Peut-être ai-je mal compris votre question: n'hésitez pas à la clarifier dans un commentaire.)
Un graphique figé peut être chargé en utilisant
tf.import_graph_def()
. Dans ce cas, les pondérations sont (généralement) intégrées dans le graphique, vous n'avez donc pas besoin de charger un point de contrôle séparé.Le principal changement serait de mettre à jour les noms du (des) tenseur (s) qui sont introduits dans le modèle, et les noms du (des) tenseur (s) qui sont extraits du modèle. Dans la démo Android de TensorFlow, cela correspondrait aux chaînes
inputName
etoutputName
transmises àTensorFlowClassifier.initializeTensorFlow()
.Il
GraphDef
s'agit de la structure du programme, qui ne change généralement pas au cours du processus de formation. Le point de contrôle est un instantané de l'état d'un processus de formation, qui change généralement à chaque étape du processus de formation. En conséquence, TensorFlow utilise différents formats de stockage pour ces types de données, et l'API de bas niveau fournit différentes manières de les enregistrer et de les charger. Les bibliothèques de niveau supérieur, telles que lesMetaGraphDef
bibliothèques, Keras et skflow, s'appuient sur ces mécanismes pour fournir des moyens plus pratiques pour enregistrer et restaurer un modèle entier.la source
tf.train.write_graph()
puis l'exécuter?GraphDef
enregistré partf.train.write_graph()
, vous devez également vous souvenir des noms des tenseurs que vous souhaitez alimenter et récupérer lors de l'exécution du graphe (point 5 ci-dessus).Vous pouvez essayer le code suivant:
la source