TensorFlow:从RNN获取所有状态

Ank*_*ani 11 python machine-learning deep-learning tensorflow

你如何从所有隐藏的状态tf.nn.rnn()tf.nn.dynamic_rnn()在TensorFlow?API只给我最终状态.

第一种选择是在构建直接在RNNCell上运行的模型时编写循环.但是,对于我来说,时间步长的数量并不固定,并且取决于传入的批次.

一些选项是使用GRU或编写我自己的RNNCell,将状态连接到输出.前者的选择不够普遍,后者听起来太过于苛刻.

另一个选择是做一些像这个问题中的答案,从RNN获取所有变量.但是,我不确定如何以标准方式将隐藏状态与其他变量分开.

在使用库提供的RNN API时,是否有一种很好的方法可以从RNN获取所有隐藏状态?

Car*_*910 0

我已经在这里创建了一个 PR ,它可能会帮助您处理简单的情况

让我简要解释一下我的实现,以便您可以根据需要编写自己的版本。主要部分是函数的修改_time_step

def _time_step(time, output_ta_t, state, *args):
Run Code Online (Sandbox Code Playgroud)

*args除了传入额外参数之外,参数保持不变。但是为什么呢args?那是因为我想支持tensorflow的习惯行为。您只需忽略args参数即可返回最终状态:

if states_ta is not None:
    # If you want to return all states, set `args` to be `states_ta`
    loop_vars = (time, output_ta, state, states_ta)
else:
    # If you want the final state only, ignore `args`
    loop_vars = (time, output_ta, state)
Run Code Online (Sandbox Code Playgroud)

如何利用它?

if args:
    args = tuple(
        ta.write(time, out) for ta, out in zip(args[0], [new_state])
    )
Run Code Online (Sandbox Code Playgroud)

事实上,这只是以下(原始)代码的修改:

output_ta_t = tuple(
    ta.write(time, out) for ta, out in zip(output_ta_t, output)
)
Run Code Online (Sandbox Code Playgroud)

现在args应该包含您想要的所有状态。

完成上述所有工作后,您可以使用以下代码获取状态(或最终状态):

_, output_final_ta, *state_info = control_flow_ops.while_loop( ...
Run Code Online (Sandbox Code Playgroud)

if states_ta is not None:
    final_state, states_final_ta = state_info
else:
    final_state, states_final_ta = state_info[0], None
Run Code Online (Sandbox Code Playgroud)

虽然我还没有在复杂的情况下测试它,但它应该在“简单”条件下工作(这是我的测试用例)