Quelle est la sortie d'un tf.nn.dynamic_rnn ()?

8

Je ne suis pas sûr de ce que je comprends de la documentation officielle, qui dit:

Renvoie: Une paire (sorties, état) où:

outputs: Le tenseur de sortie RNN.

Si time_major == False( par défaut), ce sera une forme Tensor: [batch_size, max_time, cell.output_size].

Si time_major == True, ce sera une forme Tensor: [max_time, batch_size, cell.output_size].

Remarque: s'il cell.output_sizes'agit d'un tuple (éventuellement imbriqué) d'entiers ou d'objets TensorShape, les sorties seront un tuple ayant la même structure que cell.output_size, contenant des tenseurs ayant des formes correspondant aux données de forme dans cell.output_size.

state: L'état final. Si cell.state_size est un int, il sera mis en forme [batch_size, cell.state_size]. S'il s'agit d'une TensorShape, celle-ci sera mise en forme [batch_size] + cell.state_size. S'il s'agit d'un tuple (éventuellement imbriqué) d'ints ou de TensorShape, ce sera un tuple ayant les formes correspondantes. Si les cellules sont des cellules LSTMC, l'état sera un tuple contenant un LSTMStateTuple pour chaque cellule.

output[-1] Est-il toujours (dans les trois types de cellules, c'est-à-dire RNN, GRU, LSTM) égal à l'état (deuxième élément du tuple de retour)? Je suppose que la littérature partout dans le monde est trop libérale dans l'utilisation du terme état caché. Est-ce que l'état caché dans les trois cellules est le score qui sort (pourquoi il est appelé caché me dépasse, il semblerait que l'état de cellule dans LSTM devrait être appelé l'état caché car il n'est pas exposé)?

MiloMinderbinder
la source

Réponses:

10

Oui, la sortie de cellule est égale à l'état caché. Dans le cas de LSTM, c'est la partie à court terme du tuple (deuxième élément de LSTMStateTuple), comme on peut le voir sur cette image:

LSTM

Mais pour tf.nn.dynamic_rnn, l' état retourné peut être différent lorsque la séquence est plus courte ( sequence_lengthargument). Jetez un œil à cet exemple:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

Ici, le lot d'entrée contient 4 séquences et l'une d'entre elles est courte et remplie de zéros. Lors de l'exécution, vous devriez quelque chose comme ceci:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

... ce qui en effet le montre state == output[1]pour les séquences complètes et state == output[0]pour la courte. Est également output[1]un vecteur zéro pour cette séquence. Il en va de même pour les cellules LSTM et GRU.

Il states'agit donc d'un tenseur pratique qui contient le dernier état RNN réel , en ignorant les zéros. Le outputtenseur détient les sorties de toutes les cellules, il n'ignore donc pas les zéros. C'est la raison de leur retour tous les deux.

Maxime
la source
2

Copie possible de /programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930

Quoi qu'il en soit, allons de l'avant avec la réponse.

Cet extrait de code pourrait aider à comprendre ce qui est réellement retourné par la dynamic_rnncouche

=> Tuple de (sorties, final_output_state) .

Ainsi, pour une entrée avec une longueur de séquence maximale de T pas de temps, les sorties ont la forme [Batch_size, T, num_inputs](donnée time_major= Faux; valeur par défaut) et contiennent l'état de sortie à chaque pas de temps h1, h2.....hT.

Et final_output_state a la forme [Batch_size,num_inputs]et a l'état de cellule final cTet l'état hTde sortie de chaque séquence de lot.

Mais puisque le dynamic_rnnest utilisé, je suppose que la longueur de votre séquence varie pour chaque lot.

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

L'affirmation finale échouera car l'état final de la 2ème séquence est au 6ème pas de temps ie. l'index 5 et le reste des sorties de [6: 9] sont tous des 0 dans le 2ème pas de temps

Bhaskar Arun
la source