Tensorflow,如何访问 RNN 的所有中间状态,而不仅仅是最后一个状态

Cen*_*tAu 5 python tensorflow

我的理解是tf.nn.dynamic_rnn在每个时间步以及最终状态返回 RNN 单元(例如 LSTM)的输出。如何在所有时间步骤中访问单元格状态,而不仅仅是最后一个?例如,我希望能够平均所有隐藏状态,然后在后续层中使用它。

以下是我如何定义 LSTM 单元,然后使用tf.nn.dynamic_rnn. 但这仅给出了 LSTM 的最后一个单元状态。

import tensorflow as tf
import numpy as np

# [batch-size, sequence-length, dimensions] 
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
out, last = sess.run([outputs, last_state], feed_dict=None)
Run Code Online (Sandbox Code Playgroud)

jas*_*ekp 3

像这样的东西应该有效。

import tensorflow as tf
import numpy as np


class CustomRNN(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomRNN, self).__init__(*args, **kwargs) # create an lstm cell
        self._output_size = self._state_size # change the output size to the state size
        return returns
    def __call__(self, inputs, state):
        output, next_state = super(CustomRNN, self).__call__(inputs, state)
        return next_state, next_state # return two copies of the state, instead of the output and the state

X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 10]

cell = CustomRNN(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
states, last_state = sess.run([outputs, last_states], feed_dict=None)
Run Code Online (Sandbox Code Playgroud)

这使用连接状态,因为我不知道是否可以存储任意数量的元组状态。states 变量的形状为(batch_size、max_time_size、state_size)。

  • LSTM 状态是输出 (m) 和隐藏状态 (c) 的组合。此代码获取输出 (m) 并将其替换为串联状态 (c + m)。不考虑批量大小,输出是 [(c1 + m1), (c2 + m2), ... ] 的列表,而不是 [m1, m2, ...]。 (2认同)