Perte soudaine de précision lors de la formation LSTM ou GRU à Keras

8

Mon réseau neuronal récurrent (LSTM, resp. GRU) se comporte d'une manière que je ne peux pas expliquer. L'entraînement commence et il s'entraîne bien (les résultats semblent assez bons) lorsque la précision diminue soudainement (et que la perte augmente rapidement) - à la fois les mesures d'entraînement et de test. Parfois, le net devient fou et renvoie des sorties aléatoires et parfois (comme dans le dernier des trois exemples donnés), il commence à renvoyer la même sortie à toutes les entrées .

image

Avez-vous une explication à ce comportement ? Toute opinion est la bienvenue. Veuillez consulter la description de la tâche et les figures ci-dessous.

La tâche: à partir d'un mot prédire son vecteur word2vec L'entrée: Nous avons un propre modèle word2vec (normalisé) et nous alimentons le réseau avec un mot (lettre par lettre). Nous remplissons les mots (voir l'exemple ci-dessous). Exemple: Nous avons un mot football et nous voulons prédire son vecteur word2vec qui fait 100 dimensions de large. Ensuite, l'entrée est $football$$$$$$$$$$.

Trois exemples de comportement:

LSTM monocouche

model = Sequential([
    LSTM(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

image

GRU monocouche

model = Sequential([
    GRU(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

image

LSTM double couche

model = Sequential([
    LSTM(512, input_shape=encoder.shape, return_sequences=True),
    TimeDistributed(Dense(512, activation="sigmoid")),
    LSTM(512, return_sequences=False),
    Dense(256, activation="tanh"),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

image

Nous avons également expérimenté ce type de comportement dans un autre projet auparavant qui utilisait une architecture similaire mais son objectif et ses données étaient différents. Ainsi, la raison ne doit pas être cachée dans les données ou dans l'objectif particulier mais plutôt dans l'architecture.

Marek
la source
avez-vous découvert la cause du problème?
Antoine
Malheureusement pas vraiment. Nous avons changé pour une architecture différente et ensuite nous n'avons pas eu l'occasion d'y revenir. Nous avons cependant quelques indices. Nous pensons que quelque chose a provoqué la modification d'un ou de plusieurs paramètres nan.
Marek
nanLe paramètre n'entraînerait pas de perte non nanométrique. Je suppose que vos gradients explosent, une chose similaire m'est arrivée dans des réseaux normalisés non batch.
Lugi
C'est également l'une des choses que nous avons essayé d'examiner à l'aide de TensorBoard, mais une explosion de gradient n'a jamais été prouvée dans notre cas. L'idée était qu'elle nanapparaissait dans l'un des calculs, puis elle était par défaut dans une autre valeur qui a rendu le réseau fou. Mais c'est juste une supposition folle. Merci pour votre avis.
Marek

Réponses:

2

Voici ma suggestion pour identifier le problème:

1) Regardez la courbe d'apprentissage de la formation: Comment est la courbe d'apprentissage sur le train? Apprend-il l'ensemble de formation? Si ce n'est pas le cas, commencez par y travailler pour vous assurer que vous pouvez vous adapter à l'ensemble d'entraînement.

2) Vérifiez vos données pour vous assurer qu'elles ne contiennent pas de NaN (formation, validation, test)

3) Vérifiez les gradients et les poids pour vous assurer qu'il n'y a pas de NaN.

4) Réduisez le taux d'apprentissage pendant que vous vous entraînez pour vous assurer que ce n'est pas à cause d'une grosse mise à jour soudaine coincée dans un minimum strict.

5) Pour vous assurer que tout va bien, vérifiez les prédictions de votre réseau afin que votre réseau ne fasse pas de prédictions constantes ou répétitives.

6) Vérifiez si vos données dans votre lot sont équilibrées par rapport à toutes les classes.

7) normalisez vos données pour qu'elles correspondent à une variance unitaire moyenne nulle. Initialisez également les poids. Il facilitera la formation.

PickleRick
la source