Je cherchais des moyens alternatifs pour enregistrer un modèle entraîné dans PyTorch. Jusqu'à présent, j'ai trouvé deux alternatives.
- torch.save () pour enregistrer un modèle et torch.load () pour charger un modèle.
- model.state_dict () pour enregistrer un modèle entraîné et model.load_state_dict () pour charger le modèle enregistré.
Je suis tombé sur cette discussion où l'approche 2 est recommandée par rapport à l'approche 1.
Ma question est la suivante: pourquoi la deuxième approche est-elle préférée? Est-ce uniquement parce que les modules torch.nn ont ces deux fonctions et que nous sommes encouragés à les utiliser?
python
serialization
deep-learning
pytorch
tensor
Wasi Ahmad
la source
la source
torch.save(model, f)
ettorch.save(model.state_dict(), f)
. Les fichiers enregistrés ont la même taille. Maintenant je suis confus. De plus, j'ai trouvé l'utilisation de pickle pour enregistrer model.state_dict () extrêmement lente. Je pense que le meilleur moyen est d'utilisertorch.save(model.state_dict(), f)
puisque vous gérez la création du modèle et que la torche gère le chargement des poids du modèle, éliminant ainsi les problèmes possibles. Référence: discuss.pytorch.org/t/saving-torch-models/838/4pickle
?Réponses:
J'ai trouvé cette page sur leur dépôt github, je vais simplement coller le contenu ici.
Approche recommandée pour enregistrer un modèle
Il existe deux approches principales pour la sérialisation et la restauration d'un modèle.
Le premier (recommandé) enregistre et charge uniquement les paramètres du modèle:
Puis plus tard:
La seconde sauvegarde et charge l'ensemble du modèle:
Puis plus tard:
Cependant, dans ce cas, les données sérialisées sont liées aux classes spécifiques et à la structure de répertoire exacte utilisée, de sorte qu'elles peuvent se briser de diverses manières lorsqu'elles sont utilisées dans d'autres projets ou après de sérieux refactors.
la source
pickle
?Cela dépend de ce que vous voulez faire.
Cas n ° 1: enregistrez le modèle pour l'utiliser vous-même pour l'inférence : vous enregistrez le modèle, vous le restaurez, puis vous changez le modèle en mode d'évaluation. Ceci est fait parce que vous avez en général
BatchNorm
et desDropout
couches qui sont par défaut en mode train sur la construction:Cas n ° 2: Enregistrer le modèle pour reprendre l'entraînement plus tard : Si vous devez continuer à entraîner le modèle que vous êtes sur le point d'enregistrer, vous devez enregistrer plus que le modèle. Vous devez également enregistrer l'état de l'optimiseur, les époques, le score, etc. Vous le feriez comme ceci:
Pour reprendre l'entraînement, vous feriez des choses comme:,
state = torch.load(filepath)
puis, pour restaurer l'état de chaque objet individuel, quelque chose comme ceci:Puisque vous reprenez l'entraînement, NE PAS appeler
model.eval()
une fois que vous avez restauré les états lors du chargement.Cas n ° 3: Modèle à utiliser par quelqu'un d'autre sans accès à votre code : Dans Tensorflow, vous pouvez créer un
.pb
fichier qui définit à la fois l'architecture et les poids du modèle. Ceci est très pratique, surtout lors de l'utilisationTensorflow serve
. La façon équivalente de faire cela dans Pytorch serait:Cette méthode n'est toujours pas à l'épreuve des balles et comme pytorch subit encore de nombreux changements, je ne le recommanderais pas.
la source
torch.load
renvoie juste un OrderedDict. Comment obtenir le modèle pour faire des prédictions?La bibliothèque pickle Python implémente des protocoles binaires pour la sérialisation et la désérialisation d'un objet Python.
Lorsque vous
import torch
(ou lorsque vous utilisez PyTorch), ce seraimport pickle
pour vous et vous n'avez pas besoin d'appelerpickle.dump()
etpickle.load()
directement, quelles sont les méthodes pour enregistrer et charger l'objet.En fait,
torch.save()
ettorch.load()
emballerapickle.dump()
etpickle.load()
pour vous.Une
state_dict
autre réponse mentionnée mérite juste quelques notes supplémentaires.Qu'avons
state_dict
-nous à l'intérieur de PyTorch? Il y a en fait deuxstate_dict
art.Le modèle est PyTorch
torch.nn.Module
amodel.parameters()
appel pour obtenir des paramètres apprenables (w et b). Ces paramètres apprenables, une fois définis aléatoirement, seront mis à jour au fil du temps à mesure que nous apprenons. Les paramètres apprenables sont les premiersstate_dict
.Le second
state_dict
est le dict d'état de l'optimiseur. Vous vous rappelez que l'optimiseur est utilisé pour améliorer nos paramètres apprenables. Mais l'optimiseurstate_dict
est fixe. Rien à apprendre là-dedans.Les
state_dict
objets étant des dictionnaires Python, ils peuvent être facilement enregistrés, mis à jour, modifiés et restaurés, ajoutant une grande modularité aux modèles et optimiseurs PyTorch.Créons un modèle super simple pour expliquer cela:
Ce code affichera les éléments suivants:
Notez qu'il s'agit d'un modèle minimal. Vous pouvez essayer d'ajouter une pile de séquences
Notez que seules les couches avec des paramètres apprenables (couches convolutionnelles, couches linéaires, etc.) et des tampons enregistrés (couches batchnorm) ont des entrées dans le modèle
state_dict
.Les choses non apprenables appartiennent à l'objet optimiseur
state_dict
, qui contient des informations sur l'état de l'optimiseur, ainsi que les hyperparamètres utilisés.Le reste de l'histoire est le même; dans la phase d'inférence (c'est une phase où l'on utilise le modèle après l'entraînement) pour la prédiction; nous prédisons en fonction des paramètres que nous avons appris. Donc, pour l'inférence, nous avons juste besoin de sauvegarder les paramètres
model.state_dict()
.Et pour utiliser plus tard model.load_state_dict (torch.load (filepath)) model.eval ()
Remarque: n'oubliez pas la dernière ligne qui
model.eval()
est cruciale après le chargement du modèle.N'essayez pas non plus de sauvegarder
torch.save(model.parameters(), filepath)
. Lemodel.parameters()
n'est que l'objet générateur.De l'autre côté,
torch.save(model, filepath)
enregistre l'objet de modèle lui-même, mais gardez à l'esprit que le modèle n'a pas d'optimiseurstate_dict
. Vérifiez l'autre excellente réponse de @Jadiel de Armas pour enregistrer le dict d'état de l'optimiseur.la source
Une convention PyTorch courante consiste à enregistrer les modèles en utilisant une extension de fichier .pt ou .pth.
Enregistrer / charger tout le modèle Enregistrer:
Charge:
La classe de modèle doit être définie quelque part
la source
Si vous souhaitez enregistrer le modèle et que vous souhaitez reprendre l'entraînement plus tard:
GPU unique: Enregistrer:
Charge:
GPU multiple: enregistrer
Charge:
la source