Comment répertorier toutes les opérations utilisées dans Tensorflow SavedModel?

10

Si tensorflow.saved_model.savej'enregistre mon modèle à l'aide de la fonction au format SavedModel, comment puis-je récupérer les Ops Tensorflow utilisées dans ce modèle par la suite. Comme le modèle peut être restauré, ces opérations sont stockées dans le graphique, ma supposition est dans le saved_model.pbfichier. Si je charge ce protobuf (donc pas le modèle entier), la partie bibliothèque du protobuf les répertorie, mais cela n'est pas documenté et étiqueté comme une fonctionnalité expérimentale pour l'instant. Les modèles créés dans Tensorflow 1.x n'auront pas cette partie.

Alors, quel est un moyen rapide et fiable pour récupérer une liste des opérations utilisées (comme MatchingFilesou WriteFile) à partir d'un modèle au format SavedModel?

En ce moment, je peux geler le tout, comme le tensorflowjs-converterfait. Comme ils vérifient également les opérations prises en charge. Cela ne fonctionne pas actuellement lorsqu'un LSTM est dans le modèle, voir ici . Y a-t-il une meilleure façon de le faire, car les opérations sont définitivement là-dedans?

Un exemple de modèle:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Attendu en sortie tous les Ops, contenant dans ce cas au moins:

  • ReadFilecomme décrit ici
  • ...
sampers
la source
1
Il est difficile de dire exactement ce que vous voulez, ce qui est saved_model.pb, est - il un tf.GraphDefou un SavedModelmessage protobuf? Si vous avez un tf.GraphDefappelé gd, vous pouvez obtenir la liste des opérations utilisées avec sorted(set(n.op for n in gd.node)). Si vous avez un modèle chargé, vous pouvez le faire sorted(set(op.type for op in tf.get_default_graph().get_operations())). Si c'est un SavedModel, vous pouvez l'obtenir tf.GraphDef(par exemple saved_model.meta_graphs[0].graph_def).
jdehesa
Je veux récupérer les ops d'un SavedModel stocké. Alors en effet, la dernière option que vous décrivez. Quelle est la saved_modelvariable dans votre dernier exemple? Résultat tf.saved_model.load('/path/to/model')ou chargement du protobuf du fichier saved_model.pb.
sampers

Réponses:

1

S'il saved_model.pbs'agit d'un SavedModelmessage protobuf, vous obtenez les opérations directement à partir de là. Disons que nous créons un modèle comme suit:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Nous pouvons maintenant trouver les opérations utilisées par ce modèle comme ceci:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin
jdehesa
la source
J'ai essayé quelque chose comme ça, mais malheureusement cela ne correspond pas à ce que j'attends: Disons que j'ai un modèle qui fait ceci: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')Ensuite, le ReadFile Op tel qu'indiqué ici est là, mais n'est pas imprimé.
sampers
1
@sampers J'ai édité la réponse avec un exemple comme vous le suggérez. J'obtiens l' ReadFileopération dans la sortie. Est-il possible que, dans votre cas réel, cette opération ne soit pas entre l'entrée et la sortie du modèle enregistré? Dans ce cas, je pense qu'il pourrait être élagué.
jdehesa
En effet avec le modèle donné cela fonctionne. Malheureusement pour un module fabriqué en tf2, ce n'est pas le cas. Si je crée un tf.Module avec 1 fonction avec une annotation d' file_nameargument @tf.function, contenant les appels que j'ai énumérés dans mon commentaire précédent, il donne la liste suivante:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers
a ajouté un modèle à ma question
sampers
@sampers J'ai mis à jour ma réponse. J'utilisais TF 1.x auparavant, je ne connaissais pas les modifications apportées aux objets de définition de graphique dans TF 2.x, je pense que la réponse couvre maintenant tout dans le modèle enregistré. Je pense que les opérations correspondant à la fonction Python dans laquelle vous avez écrit sont saved_model.meta_graphs[0].graph_def.library.function[0](la node_defcollection dans cet objet fonction).
jdehesa