TensorFlow, pourquoi il y a 3 fichiers après avoir enregistré le modèle?

113

Après avoir lu la documentation , j'ai enregistré un modèle dans TensorFlow, voici mon code de démonstration:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

mais après cela, j'ai trouvé qu'il y avait 3 fichiers

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

Et je ne peux pas restaurer le modèle en restaurant le model.ckptfichier, car il n'y a pas de tel fichier. Voici mon code

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Alors, pourquoi il y a 3 fichiers?

GoingMyWay
la source
2
Avez-vous trouvé comment résoudre ce problème? Comment puis-je recharger le modèle (en utilisant Keras)?
rajkiran

Réponses:

116

Essaye ça:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

La méthode de sauvegarde TensorFlow enregistre trois types de fichiers car elle stocke la structure du graphique séparément des valeurs de variable . Le .metafichier décrit la structure du graphique enregistré, vous devez donc l'importer avant de restaurer le point de contrôle (sinon il ne sait pas à quelles variables les valeurs de point de contrôle enregistrées correspondent).

Vous pouvez également faire ceci:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Même s'il n'y a pas de fichier nommé model.ckpt, vous faites toujours référence au point de contrôle enregistré par ce nom lors de sa restauration. À partir du saver.pycode source :

Les utilisateurs doivent uniquement interagir avec le préfixe spécifié par l'utilisateur ... au lieu d'un chemin d'accès physique.

TK Bartel
la source
1
donc le .index et le .data ne sont pas utilisés? Quand ces 2 fichiers sont-ils alors utilisés?
ajfbiw.s
26
@ ajfbiw.s .meta stocke la structure du graphe, .data stocke les valeurs de chaque variable dans le graphe, .index identifie le point de contrôle. Donc, dans l'exemple ci-dessus: import_meta_graph utilise le .meta, et saver.restore utilise les .data et .index
TK Bartel
Oh je vois. Merci.
ajfbiw.s
1
Y a-t-il une chance que vous ayez enregistré le modèle avec une version de TensorFlow différente de celle que vous utilisez pour le charger? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel
5
Est-ce que quelqu'un sait ce que cela signifie 00000et les 00001chiffres? dans le variables.data-?????-of-?????fichier
Ivan Talalaev
55
  • meta file : décrit la structure graphique enregistrée, inclut GraphDef, SaverDef, etc. puis appliquez tf.train.import_meta_graph('/tmp/model.ckpt.meta'), restaurera Saveret Graph.

  • fichier d'index : c'est une table immuable de chaîne de caractères (tensorflow :: table :: Table). Chaque clé est le nom d'un tenseur et sa valeur est un BundleEntryProto sérialisé. Chaque BundleEntryProto décrit les métadonnées d'un tenseur: lequel des fichiers «données» contient le contenu d'un tenseur, le décalage dans ce fichier, la somme de contrôle, certaines données auxiliaires, etc.

  • fichier de données : il s'agit de la collection TensorBundle, enregistrez les valeurs de toutes les variables.

Guangcong Liu
la source
J'ai le fichier pb que j'ai pour la classification des images. Puis-je l'utiliser pour la classification vidéo en temps réel?
Pouvez-vous s'il vous plaît me faire savoir, en utilisant Keras 2, comment charger le modèle s'il est enregistré en 3 fichiers?
rajkiran
5

Je restaure les incorporations de mots entraînés à partir du didacticiel tensorflow de Word2Vec.

Si vous avez créé plusieurs points de contrôle:

par exemple, les fichiers créés ressemblent à ceci

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

essaye ça

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

lors de l'appel de restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
Steven Wong
la source
Que signifie "00000-of-00001" dans "model.ckpt-55695.data-00000-of-00001"?
hafiz031
0

Si vous avez formé un CNN avec abandon, par exemple, vous pouvez le faire:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
happy_sisyphus
la source