Après avoir formé un modèle dans Tensorflow:
- Comment enregistrez-vous le modèle formé?
- Comment restaurer ultérieurement ce modèle enregistré?
python
tensorflow
machine-learning
model
mathetes
la source
la source
Réponses:
Documents
tutoriel exhaustif et utile -> https://www.tensorflow.org/guide/saved_model
Guide détaillé de Keras pour enregistrer des modèles -> https://www.tensorflow.org/guide/keras/save_and_serialize
De la documentation:
sauver
Restaurer
Tensorflow 2
Ceci est encore en version bêta, donc je déconseille pour l'instant. Si vous voulez toujours emprunter cette voie, voici le
tf.saved_model
guide d'utilisationTensorflow <2
simple_save
Beaucoup de bonne réponse, pour être complet j'ajouterai mes 2 cents: simple_save . Également un exemple de code autonome utilisant l'
tf.data.Dataset
API.Python 3; Tensorflow 1.14
Restauration:
Exemple autonome
Article de blog original
Le code suivant génère des données aléatoires pour la démonstration.
Dataset
puis sonIterator
. Nous obtenons le tenseur généré par l'itérateur, appeléinput_tensor
qui servira d'entrée à notre modèle.input_tensor
: d'un RNN bidirectionnel basé sur GRU suivi d'un classificateur dense. Parce que pourquoi pas.softmax_cross_entropy_with_logits
, optimisée avecAdam
. Après 2 époques (de 2 lots chacune), nous sauvegardons le modèle "entraîné" avectf.saved_model.simple_save
. Si vous exécutez le code tel quel, le modèle sera enregistré dans un dossier appelésimple/
dans votre répertoire de travail actuel.tf.saved_model.loader.load
. Nous récupérons les espaces réservés et les logits avecgraph.get_tensor_by_name
et l'Iterator
opération d'initialisation avecgraph.get_operation_by_name
.Code:
Cela imprimera:
la source
tf.contrib.layers
?[n.name for n in graph2.as_graph_def().node]
. Comme le dit la documentation, la sauvegarde simple vise à simplifier l'interaction avec le service tensorflow, c'est le point des arguments; d'autres variables sont cependant toujours restaurées, sinon l'inférence ne se produirait pas. Saisissez simplement vos variables d'intérêt comme je l'ai fait dans l'exemple. Consultez la documentationglobal_step
argument, si vous vous arrêtez puis essayez de reprendre la formation, il pensera que vous êtes une étape. Cela gâchera au moins vos visualisations de tensorboardJ'améliore ma réponse pour ajouter plus de détails sur la sauvegarde et la restauration des modèles.
Dans (et après) la version 0.11 de Tensorflow :
Enregistrez le modèle:
Restaurer le modèle:
Ceci et quelques cas d'utilisation plus avancés ont été très bien expliqués ici.
Un tutoriel complet rapide pour enregistrer et restaurer les modèles Tensorflow
la source
:0
aux noms?Dans (et après) la version 0.11.0RC1 de TensorFlow, vous pouvez enregistrer et restaurer votre modèle directement en appelant
tf.train.export_meta_graph
ettf.train.import_meta_graph
selon https://www.tensorflow.org/programmers_guide/meta_graph .Enregistrez le modèle
Restaurer le modèle
la source
<built-in function TF_Run> returned a result with an error set
tf.get_variable_scope().reuse_variables()
suivi devar = tf.get_variable("varname")
. Cela me donne l'erreur: "ValueError: la variable varname n'existe pas ou n'a pas été créée avec tf.get_variable ()." Pourquoi? Cela ne devrait-il pas être possible?Pour la version TensorFlow <0.11.0RC1:
Les points de contrôle enregistrés contiennent des valeurs pour les
Variable
s dans votre modèle, pas le modèle / graphique lui-même, ce qui signifie que le graphique doit être le même lorsque vous restaurez le point de contrôle.Voici un exemple de régression linéaire où il y a une boucle d'apprentissage qui enregistre les points de contrôle des variables et une section d'évaluation qui restaurera les variables enregistrées lors d'une exécution précédente et calculera les prédictions. Bien sûr, vous pouvez également restaurer des variables et continuer la formation si vous le souhaitez.
Voici les documents pour
Variable
s, qui couvrent la sauvegarde et la restauration. Et voici les documents pour leSaver
.la source
batch_x
doit être? Binaire? Tableau Numpy?undefined
. Pouvez-vous me dire quelle est la définition de FLAGS pour ce code. @RyanSepassiMon environnement: Python 3.6, Tensorflow 1.3.0
Bien qu'il y ait eu de nombreuses solutions, la plupart d'entre elles sont basées sur
tf.train.Saver
. Lorsque nous chargeons un.ckpt
sauvé parSaver
, nous devons redéfinir soit le réseau tensorflow ou utiliser un bizarre nom et dur rappeler, par exemple'placehold_0:0'
,'dense/Adam/Weight:0'
. Ici, je recommande d'utilisertf.saved_model
, un exemple le plus simple donné ci-dessous, vous pouvez en apprendre plus sur Serving a TensorFlow Model :Enregistrez le modèle:
Chargez le modèle:
la source
Le modèle comporte deux parties, la définition du modèle, enregistrée par
Supervisor
commegraph.pbtxt
dans le répertoire du modèle et les valeurs numériques des tenseurs, enregistrées dans des fichiers de point de contrôle commemodel.ckpt-1003418
.La définition du modèle peut être restaurée à l'aide
tf.import_graph_def
et les poids sont restaurés à l'aideSaver
.Cependant,
Saver
utilise une collection spéciale contenant une liste de variables attachées au graphique du modèle, et cette collection n'est pas initialisée à l'aide d'import_graph_def, vous ne pouvez donc pas utiliser les deux ensemble pour le moment (c'est sur notre feuille de route à corriger). Pour l'instant, vous devez utiliser l'approche de Ryan Sepassi - construire manuellement un graphique avec des noms de nœuds identiques, et utiliserSaver
pour y charger les poids.(Alternativement, vous pouvez le pirater en utilisant en utilisant
import_graph_def
, en créant des variables manuellement et en utilisanttf.add_to_collection(tf.GraphKeys.VARIABLES, variable)
pour chaque variable, puis en utilisantSaver
)la source
Vous pouvez également utiliser cette méthode plus facilement.
Étape 1: initialisez toutes vos variables
Étape 2: enregistrez la session dans le modèle
Saver
et enregistrez-laÉtape 3: restaurer le modèle
Étape 4: vérifiez votre variable
Lors de l'exécution dans une instance python différente, utilisez
la source
Dans la plupart des cas, l'enregistrement et la restauration à partir du disque à l'aide de
tf.train.Saver
est votre meilleure option:Vous pouvez également enregistrer / restaurer la structure graphique elle-même (voir la documentation MetaGraph pour plus de détails). Par défaut, l'
Saver
enregistre la structure du graphique dans un.meta
fichier. Vous pouvez appelerimport_meta_graph()
pour le restaurer. Il restaure la structure du graphique et renvoie unSaver
que vous pouvez utiliser pour restaurer l'état du modèle:Cependant, il y a des cas où vous avez besoin de quelque chose de beaucoup plus rapide. Par exemple, si vous implémentez un arrêt anticipé, vous souhaitez enregistrer les points de contrôle chaque fois que le modèle s'améliore pendant la formation (mesuré sur l'ensemble de validation), puis s'il n'y a pas de progrès pendant un certain temps, vous souhaitez revenir au meilleur modèle. Si vous enregistrez le modèle sur le disque chaque fois qu'il s'améliore, cela ralentit considérablement la formation. L'astuce consiste à enregistrer les états des variables en mémoire , puis à les restaurer plus tard:
Une explication rapide: lorsque vous créez une variable
X
, TensorFlow crée automatiquement une opération d'affectationX/Assign
pour définir la valeur initiale de la variable. Au lieu de créer des espaces réservés et des opérations d'affectation supplémentaires (ce qui rendrait le graphique désordonné), nous utilisons simplement ces opérations d'affectation existantes. La première entrée de chaque affectation op est une référence à la variable qu'elle est censée initialiser, et la deuxième entrée (assign_op.inputs[1]
) est la valeur initiale. Donc, pour définir la valeur que nous voulons (au lieu de la valeur initiale), nous devons utiliser afeed_dict
et remplacer la valeur initiale. Oui, TensorFlow vous permet de fournir une valeur pour n'importe quelle opération, pas seulement pour les espaces réservés, donc cela fonctionne très bien.la source
Comme l'a dit Yaroslav, vous pouvez pirater la restauration à partir d'un graph_def et d'un point de contrôle en important le graphique, en créant manuellement des variables, puis en utilisant un économiseur.
J'ai implémenté cela pour mon usage personnel, donc je pensais partager le code ici.
Lien: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(Il s'agit bien sûr d'un hack, et rien ne garantit que les modèles enregistrés de cette manière resteront lisibles dans les futures versions de TensorFlow.)
la source
S'il s'agit d'un modèle enregistré en interne, vous spécifiez simplement un restaurateur pour toutes les variables comme
et l'utiliser pour restaurer des variables dans une session en cours:
Pour le modèle externe, vous devez spécifier le mappage entre ses noms de variables et vos noms de variables. Vous pouvez afficher les noms des variables de modèle à l'aide de la commande
Le script inspect_checkpoint.py se trouve dans le dossier './tensorflow/python/tools' de la source Tensorflow.
Pour spécifier le mappage, vous pouvez utiliser mon Tensorflow-Worklab , qui contient un ensemble de classes et de scripts pour former et recycler différents modèles. Il comprend un exemple de recyclage des modèles ResNet, situé ici
la source
all_variables()
est désormais obsolèteVoici ma solution simple pour les deux cas de base différents selon que vous souhaitez charger le graphique à partir d'un fichier ou le créer pendant l'exécution.
Cette réponse est valable pour Tensorflow 0.12+ (y compris 1.0).
Reconstruire le graphique en code
Économie
Chargement
Chargement également du graphique à partir d'un fichier
Lorsque vous utilisez cette technique, assurez-vous que toutes vos couches / variables ont explicitement défini des noms uniques.Sinon, Tensorflow rendra les noms uniques eux-mêmes et ils seront donc différents des noms stockés dans le fichier. Ce n'est pas un problème dans la technique précédente, car les noms sont "mutilés" de la même manière lors du chargement et de l'enregistrement.
Économie
Chargement
la source
global_step
variable et les moyennes mobiles de normalisation par lots sont des variables non entraînables, mais les deux valent vraiment la peine d'être sauvegardées. En outre, vous devez distinguer plus clairement la construction du graphique de l'exécution de la session, par exempleSaver(...).save()
, vous créerez de nouveaux nœuds chaque fois que vous l'exécuterez. Probablement pas ce que vous voulez. Et il y a plus ...: /Vous pouvez également consulter des exemples dans TensorFlow / skflow , qui propose des méthodes
save
etrestore
qui peuvent vous aider à gérer facilement vos modèles. Il possède des paramètres que vous pouvez également contrôler à quelle fréquence vous souhaitez sauvegarder votre modèle.la source
Si vous utilisez tf.train.MonitoredTrainingSession comme session par défaut, vous n'avez pas besoin d'ajouter de code supplémentaire pour enregistrer / restaurer des choses. Passez simplement un nom de répertoire de point de contrôle au constructeur de MonitoredTrainingSession, il utilisera des hooks de session pour les gérer.
la source
Toutes les réponses ici sont excellentes, mais je veux ajouter deux choses.
Tout d'abord, pour développer la réponse de @ user7505159, le "./" peut être important à ajouter au début du nom de fichier que vous restaurez.
Par exemple, vous pouvez enregistrer un graphique sans "./" dans le nom de fichier comme ceci:
Mais pour restaurer le graphique, vous devrez peut-être ajouter un "./" au nom_fichier:
Vous n'aurez pas toujours besoin du "./", mais cela peut poser des problèmes en fonction de votre environnement et de la version de TensorFlow.
Il faut également mentionner que le
sess.run(tf.global_variables_initializer())
peut être important avant de restaurer la session.Si vous recevez une erreur concernant les variables non initialisées lorsque vous essayez de restaurer une session enregistrée, assurez-vous de l'inclure
sess.run(tf.global_variables_initializer())
avant lasaver.restore(sess, save_file)
ligne. Cela peut vous éviter des maux de tête.la source
Comme décrit dans le numéro 6255 :
au lieu de
la source
Selon la nouvelle version Tensorflow,
tf.train.Checkpoint
est le moyen préférable d'enregistrer et de restaurer un modèle:Voici un exemple:
Plus d'informations et d'exemple ici.
la source
Pour tensorflow 2.0 , c'est aussi simple que
Restaurer:
la source
tf.keras Enregistrement de modèle avec
TF2.0
Je vois d'excellentes réponses pour enregistrer des modèles à l'aide de TF1.x. Je veux fournir quelques conseils supplémentaires pour enregistrer des
tensorflow.keras
modèles, ce qui est un peu compliqué car il existe de nombreuses façons de sauvegarder un modèle.Ici, je fournis un exemple d'enregistrement d'un
tensorflow.keras
modèle dans unmodel_path
dossier sous le répertoire actuel. Cela fonctionne bien avec le tensorflow le plus récent (TF2.0). Je mettrai à jour cette description en cas de changement dans un avenir proche.Enregistrement et chargement du modèle entier
Enregistrement et chargement des poids du modèle uniquement
Si vous souhaitez enregistrer uniquement les poids du modèle, puis charger des poids pour restaurer le modèle,
Enregistrement et restauration à l'aide du rappel de point de contrôle keras
modèle d'enregistrement avec des mesures personnalisées
Enregistrement du modèle de keras avec des opérations personnalisées
Lorsque nous avons des opérations personnalisées comme dans le cas suivant (
tf.tile
), nous devons créer une fonction et encapsuler avec une couche Lambda. Sinon, le modèle ne peut pas être enregistré.Je pense avoir couvert quelques-unes des nombreuses façons de sauvegarder le modèle tf.keras. Cependant, il existe de nombreuses autres façons. Veuillez commenter ci-dessous si vous voyez que votre cas d'utilisation n'est pas couvert ci-dessus. Merci!
la source
Utilisez tf.train.Saver pour enregistrer un modèle, remerber, vous devez spécifier la var_list, si vous voulez réduire la taille du modèle. La val_list peut être tf.trainable_variables ou tf.global_variables.
la source
Vous pouvez enregistrer les variables dans le réseau en utilisant
Pour restaurer le réseau pour une réutilisation ultérieure ou dans un autre script, utilisez:
Les points importants:
sess
doit être le même entre la première et la dernière exécution (structure cohérente).saver.restore
a besoin du chemin du dossier des fichiers enregistrés, pas d'un chemin de fichier individuel.la source
Où que vous souhaitiez enregistrer le modèle,
Assurez-vous que vous
tf.Variable
avez tous des noms, car vous souhaiterez peut-être les restaurer ultérieurement en utilisant leurs noms. Et où vous voulez prédire,Assurez-vous que l'économiseur s'exécute dans la session correspondante. N'oubliez pas que si vous utilisez le
tf.train.latest_checkpoint('./')
, seul le dernier point de contrôle sera utilisé.la source
Je suis sur la version:
Un moyen simple est
Sauver:
Restaurer:
la source
Pour tensorflow-2.0
c'est très simple.
ENREGISTRER
RESTAURER
la source
Suite à la réponse de @Vishnuvardhan Janapati, voici une autre façon d'enregistrer et de recharger le modèle avec une couche / métrique / perte personnalisée sous TensorFlow 2.0.0
De cette façon, une fois que vous avez exécuté de tels codes et enregistré votre modèle avec
tf.keras.models.save_model
oumodel.save
ouModelCheckpoint
rappel, vous pouvez recharger votre modèle sans avoir besoin d'objets personnalisés précis, aussi simple quela source
Dans la nouvelle version de tensorflow 2.0, le processus de sauvegarde / chargement d'un modèle est beaucoup plus facile. En raison de l'implémentation de l'API Keras, une API de haut niveau pour TensorFlow.
Pour enregistrer un modèle: consultez la documentation pour référence: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model
Pour charger un modèle:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model
la source
Voici un exemple simple utilisant le format Tensorflow 2.0 SavedModel (qui est le format recommandé, selon les documents ) pour un classificateur de jeu de données MNIST simple, en utilisant l'API fonctionnelle Keras sans trop de fantaisie:
Qu'est-ce que c'est
serving_default
?Il s'agit du nom de la signature par défaut du tag que vous avez sélectionné (dans ce cas, le
serve
tag par défaut a été sélectionné). Également, ici explique comment trouver les balises et les signatures d'un modèle à l'aidesaved_model_cli
.Avertissements
Ce n'est qu'un exemple de base si vous voulez simplement le mettre en service, mais ce n'est en aucun cas une réponse complète - je pourrai peut-être le mettre à jour à l'avenir. Je voulais juste donner un exemple simple en utilisant le
SavedModel
TF 2.0 car je n'en ai vu aucun, même si simple, nulle part.@ La réponse de Tom est un exemple de SavedModel, mais cela ne fonctionnera pas sur Tensorflow 2.0, car malheureusement il y a des changements de rupture.
@ La réponse de Vishnuvardhan Janapati dit TF 2.0, mais ce n'est pas pour le format SavedModel.
la source