Quels sont exactement les mécanismes d'attention?

23

Les mécanismes d'attention ont été utilisés dans divers articles sur le Deep Learning au cours des dernières années. Ilya Sutskever, responsable de la recherche chez Open AI, les a félicités avec enthousiasme: https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Eugenio Culurciello de l'Université Purdue a déclaré que les RNN et les LSTM devraient être abandonnés au profit de réseaux de neurones purement basés sur l'attention:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Cela semble exagéré, mais il est indéniable que les modèles purement basés sur l'attention ont assez bien réussi dans les tâches de modélisation séquentielle: nous connaissons tous le papier bien nommé de Google, l' attention est tout ce dont vous avez besoin

Cependant, quels sont exactement les modèles basés sur l'attention? Je n'ai pas encore trouvé d'explication claire sur ces modèles. Supposons que je veuille prévoir les nouvelles valeurs d'une série chronologique multivariée, compte tenu de ses valeurs historiques. Il est assez clair comment faire cela avec un RNN ayant des cellules LSTM. Comment pourrais-je faire de même avec un modèle basé sur l'attention?

DeltaIV
la source

Réponses:

20

L'attention est une méthode d'agrégation d'un ensemble de vecteurs vi en un seul vecteur, souvent via un vecteur de recherche u . Habituellement, vi est soit les entrées du modèle, soit les états cachés des pas de temps précédents, soit les états cachés un niveau plus bas (dans le cas des LSTM empilés).

Le résultat est souvent appelé le vecteur de contexte c , car il contient le contexte correspondant au pas de temps actuel.

Ce vecteur de contexte supplémentaire c est ensuite également introduit dans le RNN / LSTM (il peut être simplement concaténé avec l'entrée d'origine). Par conséquent, le contexte peut être utilisé pour aider à la prédiction.

La façon la plus simple de le faire est de calculer le vecteur de probabilité p=softmax(VTu) et c=ipiviV est la concaténation de tous les vi précédents . Un vecteur de recherche commun u est l'état caché actuel ht .

Il existe de nombreuses variantes à ce sujet et vous pouvez rendre les choses aussi compliquées que vous le souhaitez. Par exemple, au lieu d'utiliser vjeTu comme logits, on peut choisir F(vje,u) place, où F est un réseau neuronal arbitraire.

Un mécanisme d'attention commun pour les modèles de séquence à séquence utilise p=softmax(qTtanh(W1vje+W2ht)) , où v sont les états cachés du codeur et ht est l'état caché actuel du décodeur. q et les deux W s sont des paramètres.

Quelques articles qui montrent différentes variations sur l'idée d'attention:

Les réseaux de pointeurs font attention aux entrées de référence afin de résoudre les problèmes d'optimisation combinatoire.

Les réseaux d'entités récurrents maintiennent des états de mémoire distincts pour différentes entités (personnes / objets) lors de la lecture de texte et mettent à jour l'état de mémoire correct en faisant attention.

Les modèles de transformateurs font également largement appel à l'attention. Leur formulation de l'attention est légèrement plus générale et implique également des vecteurs clés kje : les poids d'attention p sont en fait calculés entre les clés et la recherche, et le contexte est ensuite construit avec le vje .


Voici une mise en œuvre rapide d'une forme d'attention, bien que je ne puisse garantir l'exactitude au-delà du fait qu'elle a réussi quelques tests simples.

RNN de base:

def rnn(inputs_split):
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
        hidden_states.append(hidden)
    return hidden_states[-1]

Avec attention, nous ajoutons seulement quelques lignes avant que le nouvel état caché soit calculé:

        if len(hidden_states) > 1:
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            context = tf.zeros_like(last_state)

        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

le code complet

shimao
la source
p=softmax(VTu)jec=jepjevjepjeVTvVTv
1
zje=vjeTup=softmax(z)pje=ejezjejz
ppje
1
oui, c'est ce que je voulais dire
shimao
@shimao J'ai créé une salle de chat , faites-moi savoir si vous seriez intéressé à parler (pas à propos de cette question)
DeltaIV