在tensorflow中获取dynamic_rnn的最后一个输出?

app*_*der 17 python tensorflow

我使用dynamic_rnn来处理MNIST数据:

# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
                         forget_bias=1.0,
                         initializer=tf.random_normal)

# Initial state
istate = lstm.zero_state(batch_size, "float")

# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)

# Output at last time point T
output_at_T = output[:, 27, :]
Run Code Online (Sandbox Code Playgroud)

完整代码:http://pastebin.com/bhf9MgMe

lstm的输入是 (batch_size, sequence_length, input_size)

因此,尺寸output_at_T(batch_size, sequence_length, num_units)在哪里num_units=200.

我需要沿sequence_length 维度获取最后一个输出.在上面的代码中,这是硬编码的27.但是,我sequence_length事先并不知道,因为它可以在我的应用程序中从批处理更改为批处理.

我试过了:

output_at_T = output[:, -1, :]
Run Code Online (Sandbox Code Playgroud)

但是它说负面索引还没有实现,我尝试使用占位符变量和常量(我可以理想地sequence_length为特定批次提供); 既没有奏效.

有什么方法可以在tensorflow atm中实现这样的东西吗?

Esc*_*tor 13

您是否注意到dynamic_rnn有两个输出?

  1. 输出1,我们称之为h,每个时间步都有输出(即h_1,h_2等),
  2. 输出2,final_state,有两个元素:cell_state,以及批处理中每个元素的最后一个输出(只要您将序列长度输入到dynamic_rnn).

所以来自:

h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )
Run Code Online (Sandbox Code Playgroud)

批处理中每个元素的最后一个状态是:

final_state.h
Run Code Online (Sandbox Code Playgroud)

请注意,这包括当序列的长度对于批处理的每个元素不同时的情况,因为我们使用sequence_length参数.


Ale*_*lex 5

这就是gather_nd的用途!

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res
Run Code Online (Sandbox Code Playgroud)

在您的情况下(假设sequence_length是一个具有每个轴0元素长度的1-D张量):

output = extract_axis_1(output, sequence_length - 1)
Run Code Online (Sandbox Code Playgroud)

现在输出是一个维度的张量[batch_size, num_cells].


Ben*_*ner 0

您应该能够output使用访问张量的形状tf.shape(output)。该tf.shape()函数将返回一个包含张量大小的一维张量output。在你的例子中,这将是(batch_size, sequence_length, num_units)

然后您应该能够提取output_at_Tas的值output[:, tf.shape(output)[1], :]