Problème avec l'exécution de object_detection_tutorial TypeError: load () manque 2 arguments positionnels requis

11

Je suis assez nouveau sur tensorflow et j'essaie d'exécuter object_detection_tutorial. Je reçois TypeErrror et je ne sais pas comment le réparer.

C'est la fonction load_model qui manque 2 arguments:

tags: ensemble de balises de chaîne pour identifier le MetaGraphDef requis. Ceux-ci doivent correspondre aux balises utilisées lors de l'enregistrement des variables à l'aide de l'API save () SavedModel.

export_dir: répertoire dans lequel se trouvent le tampon de protocole SavedModel et les variables à charger.

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model
WARNING:tensorflow:From <ipython-input-9-f8a3c92a04a4>:11: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-e10c73a22cc9> in <module>
      1 model_name = 'ssd_mobilenet_v1_coco_2017_11_17'
----> 2 detection_model = load_model(model_name)

<ipython-input-9-f8a3c92a04a4> in load_model(model_name)
      9   model_dir = pathlib.Path(model_dir)/"saved_model"
     10 
---> 11   model = tf.saved_model.load(str(model_dir))
     12   model = model.signatures['serving_default']
     13 

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

TypeError: load() missing 2 required positional arguments: 'tags' and 'export_dir'

Pouvez-vous m'aider à résoudre ce problème et à exécuter mon premier détecteur d'objets: D?

Dominik
la source

Réponses:

14

J'ai eu le même problème et j'essaie de résoudre ce problème depuis 1 semaine maintenant. Je suppose que la solution devrait être la suivante;

model = tf.compat.v2.saved_model.load(str(model_dir), None)

Plus de détails seraient (sur le site officiel );

Chargez un SavedModel depuis export_dir.

tf.saved_model.load(
    export_dir,
    tags=None
)

Alias:

tf.compat.v1.saved_model.load_v2

tf.compat.v2.saved_model.load
Onur Baskin
la source
1
J'ai utilisé votre solution et j'ai eu une autre erreur. J'ai mis à jour tout ce que je pouvais et ça marche! J'ai également eu une erreur avec pathlib n'étant pas installé.
Dominik
@Dominik pouvez-vous être plus précis? peut-être que je peux aider parce que cette aventure tensorflow m'a amené à résoudre beaucoup de problèmes: D
Onur Baskin
4
@OnurBaskin Il y a une erreur plus tard: L'argument TypeError: int () doit être une chaîne, un objet de type octets ou un nombre, pas 'Tensor'
kaitsu
@Dominik Je suppose que c'est votre version Tensorflow. Ce devrait être la version 2.0 (stable). Voici le lien vers la question que j'ai posée peut-être que vous rencontrez l'erreur exacte. Recherchez également toute ancienne importation qui nécessite «compat.v1». plus tard, vous devriez avoir beaucoup plus d'erreurs, mais c'est ainsi que vous migrez un ancien code.
Onur Baskin
@OnurBaskin Je suis assez confus. Je pensais que l'API de détection d'objets n'était compatible qu'avec les versions de TensorFlow 1.
Biiiiiird
0

J'ai deviné que c'était un problème de branche et utiliser la branche tf_2_1_reference a fait l'affaire pour moi:

igian@iGians-MBP models % git checkout tf_2_1_reference
M   research/object_detection/object_detection_tutorial.ipynb
Branch 'tf_2_1_reference' set up to track remote branch 'tf_2_1_reference' from 'origin'.
Switched to a new branch 'tf_2_1_reference'
igians@iGians-MBP models % jupyter notebook

Puis exécuté chaque cellule jupiter du tutoriel comme un bon débutant!

Voici la branche que j'ai utilisée: https://github.com/tensorflow/models/tree/tf_2_1_reference

iGian
la source