Comment les gradients se propagent-ils dans un réseau neuronal récurrent non déroulé?

8

J'essaie de comprendre comment les rnn peuvent être utilisés pour prédire des séquences en travaillant à travers un exemple simple. Voici mon réseau simple, composé d'une entrée, d'un neurone caché et d'une sortie:

entrez la description de l'image ici

Le neurone caché est la fonction sigmoïde et la sortie est considérée comme une simple sortie linéaire. Donc, je pense que le réseau fonctionne comme suit: si l'unité cachée commence dans l'état s, et que nous traitons un point de données qui est une séquence de longueur , , alors:3(x1,x2,x3)

Au moment 1, la valeur prédite, , estp1

p1=u×σ(ws+vx1)

Parfois 2, nous avons

p2=u×σ(w×σ(ws+vx1)+vx2)

Parfois 3, nous avons

p3=u×σ(w×σ(w×σ(ws+vx1)+vx2)+vx3)

Jusqu'ici tout va bien?

Le rnn "déroulé" ressemble à ceci:

entrez la description de l'image ici

Si nous utilisons une somme de termes d'erreur carrés pour la fonction objectif, comment est-elle définie? Sur toute la séquence? Dans ce cas, nous aurions quelque chose comme ?E=(p1x1)2+(p2x2)2+(p3x3)2

Les poids sont-ils mis à jour uniquement une fois que la séquence entière a été examinée (dans ce cas, la séquence en 3 points)?

En ce qui concerne le gradient par rapport aux poids, nous devons calculer , je vais essayer de le faire simplement en examinant les 3 équations de ci-dessus, si tout le reste semble correct. En plus de le faire de cette façon, cela ne ressemble pas à une rétro-propagation de la vanille, car les mêmes paramètres apparaissent dans différentes couches du réseau. Comment ajustons-nous cela?dE/dw,dE/dv,dE/dupi

Si quelqu'un peut m'aider à travers cet exemple de jouet, je serais très reconnaissant.

Fequish
la source
Je pense que quelque chose ne va pas avec la fonction d'erreur, vous obtenez probablement comme deuxième terme d'élément et vous devez le comparer probablement avec , dans le cas parfait, ils doivent être égaux. Dans votre fonction d'erreur, vous comparez simplement l'entrée et la sortie du réseau. p1x2
itdxer
J'ai pensé que cela pourrait être le cas. Mais alors comment l'erreur est-elle définie pour le dernier élément prédit, ? p3
Fequish

Réponses:

1

Je pense que vous avez besoin de valeurs cibles. Donc, pour la séquence , vous auriez besoin de cibles correspondantes . Puisque vous semblez vouloir prédire le prochain terme de la séquence d'entrée d'origine, vous aurez besoin de: (x1,x2,x3)(t1,t2,t3)

t1=x2, t2=x3, t3=x4

Vous auriez besoin de définir , donc si vous aviez une séquence d'entrée de longueur pour former le RNN, vous ne seriez en mesure d'utiliser que les premiers termes comme valeurs d'entrée et les derniers termes comme cible valeurs.x4NN1N1

Si nous utilisons une somme de termes d'erreur carrés pour la fonction objectif, comment est-elle définie?

Pour autant que je sache, vous avez raison - l'erreur est la somme sur toute la séquence. Cela est dû au fait que les poids , et sont les mêmes à travers le RNN déplié.uvw

Donc,

E=tEt=t(ttpt)2

Les poids sont-ils mis à jour uniquement une fois que la séquence entière a été examinée (dans ce cas, la séquence en 3 points)?

Oui, si j'utilise la propagation arrière dans le temps, je le pense.

En ce qui concerne les différentiels, vous ne voudrez pas étendre l'expression entière pour et la différencier quand il s'agit de RNN plus grands. Ainsi, certaines notations peuvent le rendre plus net:E

  • Soit signal d'entrée du neurone caché au temps (ie )zttz1=ws+vx1
  • Soit la sortie du neurone caché au temps (ie ytty1=σ(ws+vx1))
  • Soity0=s
  • Soitδt=Ezt

Ensuite, les dérivés sont:

Eu=ytEv=tδtxtEw=tδtyt1

Où pour une séquence de longueur , et:t[1, T]T

δt=σ(zt)(u+δt+1w)

Cette relation récurrente vient du fait de réaliser que l' activité cachée non seulement affecte l'erreur à la sortie , , mais elle affecte également le reste de l'erreur plus bas sur le RNN, :tthtthEtEEt

Ezt=Etytytzt+(EEt)zt+1zt+1ytytztEzt=ytzt(Etyt+(EEt)zt+1zt+1yt)Ezt=σ(zt)(u+(EEt)zt+1w)δt=Ezt=σ(zt)(u+δt+1w)

En plus de le faire de cette façon, cela ne ressemble pas à une rétro-propagation de la vanille, car les mêmes paramètres apparaissent dans différentes couches du réseau. Comment ajustons-nous cela?

Cette méthode est appelée rétropropagation dans le temps (BPTT) et est similaire à la rétropropagation dans le sens où elle utilise l'application répétée de la règle de chaîne.

Un exemple travaillé plus détaillé mais compliqué pour un RNN peut être trouvé dans le chapitre 3.2 de 'Étiquetage de séquence supervisé avec des réseaux de neurones récurrents' par Alex Graves - lecture vraiment intéressante!

dok
la source
0

Erreur que vous décrivez ci-dessus (après une modification que j'ai écrite en commentaire sous la question), vous ne pouvez l'utiliser que comme une erreur de prédiction totale, mais vous ne pouvez pas l'utiliser dans le processus d'apprentissage. À chaque itération, vous mettez une valeur d'entrée dans le réseau et obtenez une sortie. Lorsque vous obtenez une sortie, vous devez vérifier le résultat de votre réseau et propager l'erreur à tous les poids. Après la mise à jour, vous placerez la valeur suivante dans l'ordre et effectuerez une prédiction pour cette valeur, puis vous propagerez également l'erreur, etc.

itdxer
la source