在张量流中,有一个lstm实现BasicLSTMCell,在其中调用tf.contrib.rnn.BasicLSTMCell.它有一个参数num_units,表示LSTM单元中的单元数.但我不知道这意味着什么.
如果我像这样定义一个lstm单元格:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(512).
Run Code Online (Sandbox Code Playgroud)
lstm_cell是什么样的?它是一个lstm节点或一个512节点的lstm层?谁可以告诉我这个?
它是一个带有512个单位的LSTM层.
BasicLSTMCell实现抽象类RNNCell.从文档:
代表RNN细胞的抽象对象.
每个RNNCell必须具有以下属性并
call使用签名实现(output, next_state) = call(input, state).[...]
细胞的这种定义不同于文献中使用的定义.在文献中,"单元格"是指具有单个标量输出的对象.该定义指的是这种单元的水平阵列.
创建LSTM层以及展开Back Propagation Trough Time的常用方法如下:
lstm_cell = tf.contrib.rnn.BasicLSTMCell(512)
outputs, final_state = tf.nn.static_rnn(cell=lstm_cell,
dtype=tf.float32,
inputs=some_input_sequence)
Run Code Online (Sandbox Code Playgroud)
哪里:
some_input_sequence是一个num_steps大小的张量列表[batch_size, input_size] outputs将包含每个元素之后的图层输出some_input_sequence.所以它又是一个num _steps大小元素列表[batch_size, 512](其中512是你单元格的单位数)final_state将在处理完整个序列后包含状态.特别是,对于LSTM,它是一个带有两个元素的命名元组,c和h(LSTM的两个状态).| 归档时间: |
|
| 查看次数: |
1266 次 |
| 最近记录: |