tf.contrib.rnn.BasicLSTMCell是单个LSTM单元还是LSTM层?

Zer*_*ero 2 tensorflow

在张量流中,有一个lstm实现BasicLSTMCell,在其中调用tf.contrib.rnn.BasicLSTMCell.它有一个参数num_units,表示LSTM单元中的单元数.但我不知道这意味着什么.

如果我像这样定义一个lstm单元格:

lstm_cell = tf.contrib.rnn.BasicLSTMCell(512).
Run Code Online (Sandbox Code Playgroud)

lstm_cell是什么样的?它是一个lstm节点或一个512节点的lstm层?谁可以告诉我这个?

Giu*_*rra 5

它是一个带有512个单位的LSTM层.

BasicLSTMCell实现抽象类RNNCell.从文档:

代表RNN细胞的抽象对象.

每个RNNCell必须具有以下属性并call使用签名实现(output, next_state) = call(input, state).

[...]

细胞的这种定义不同于文献中使用的定义.在文献中,"单元格"是指具有单个标量输出的对象.该定义指的是这种单元的水平阵列.

创建LSTM层以及展开Back Propagation Trough Time的常用方法如下:

lstm_cell = tf.contrib.rnn.BasicLSTMCell(512)
outputs, final_state = tf.nn.static_rnn(cell=lstm_cell,
                           dtype=tf.float32,
                           inputs=some_input_sequence)
Run Code Online (Sandbox Code Playgroud)

哪里:

  • some_input_sequence是一个num_steps大小的张量列表[batch_size, input_size]
  • outputs将包含每个元素之后的图层输出some_input_sequence.所以它又是一个num _steps大小元素列表[batch_size, 512](其中512是你单元格的单位数)
  • final_state将在处理完整个序列后包含状态.特别是,对于LSTM,它是一个带有两个元素的命名元组,ch(LSTM的两个状态).