mem*_*emo 2 c++ lstm tensorflow recurrent-neural-network
我想在python中构建和训练一个多层LSTM模型(stateIsTuple=True),然后在C++中加载和使用它。但是我很难弄清楚如何在 C++ 中提供和获取状态,主要是因为我没有可以引用的字符串名称。
例如,我将初始状态放在一个命名范围内,例如
with tf.name_scope('rnn_input_state'):
self.initial_state = cell.zero_state(args.batch_size, tf.float32)
Run Code Online (Sandbox Code Playgroud)
这出现在下图中,但是我如何在 C++ 中输入这些内容?
另外,如何在 C++ 中获取当前状态?我在 python 中尝试了下面的图形构造代码,但我不确定这是否正确,因为 last_state 应该是张量元组,而不是单个张量(尽管我可以看到张量板中的 last_state 节点是 2x2x50x128,这听起来像是连接了状态,因为我有 2 层,128 rnn 大小,50 迷你批量大小和 lstm 单元 - 具有 2 个状态向量)。
with tf.name_scope('outputs'):
outputs, last_state = legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None)
output = tf.reshape(tf.concat(outputs, 1), [-1, args.rnn_size], name='output')
Run Code Online (Sandbox Code Playgroud)
这就是张量板中的样子
我是否应该连接和拆分状态张量,以便只有一个状态张量进出?或者,还有更好的方法?
PS 理想情况下,该解决方案不会涉及对层数(或 rnn 大小)进行硬编码。所以我可以只有四个字符串 input_node_name、output_node_name、input_state_name、output_state_name,其余的都是从那里派生的。
我通过手动将状态连接成一个张量来做到这一点。我不确定这是否明智,因为这是 tensorflow用来处理状态的方式,但现在正在弃用它并切换到元组状态。我没有设置 state_is_tuple=False 并冒着我的代码很快过时的风险,而是添加了额外的操作来手动将状态堆叠到单个张量或从单个张量中解除堆叠。也就是说,它在 python 和 C++ 中都可以正常工作。
关键代码是:
# setting up
zero_state = cell.zero_state(batch_size, tf.float32)
state_in = tf.identity(zero_state, name='state_in')
# based on https://medium.com/@erikhallstrm/using-the-tensorflow-multilayered-lstm-api-f6e7da7bbe40#.zhg4zwteg
state_per_layer_list = tf.unstack(state_in, axis=0)
state_in_tuple = tuple(
# TODO make this not hard-coded to LSTM
[tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1])
for idx in range(num_layers)]
)
outputs, state_out_tuple = legacy_seq2seq.rnn_decoder(inputs, state_in_tuple, cell, loop_function=loop if infer else None)
state_out = tf.identity(state_out_tuple, name='state_out')
# running (training or inference)
state = sess.run('state_in:0') # zero state
loop:
feed = {'data_in:0': x, 'state_in:0': state}
[y, state] = sess.run(['data_out:0', 'state_out:0'], feed)
Run Code Online (Sandbox Code Playgroud)
如果有人需要,这里是完整的代码 https://github.com/memo/char-rnn-tensorflow