Ank*_*ani 11 python machine-learning deep-learning tensorflow
你如何从所有隐藏的状态tf.nn.rnn()或tf.nn.dynamic_rnn()在TensorFlow?API只给我最终状态.
第一种选择是在构建直接在RNNCell上运行的模型时编写循环.但是,对于我来说,时间步长的数量并不固定,并且取决于传入的批次.
一些选项是使用GRU或编写我自己的RNNCell,将状态连接到输出.前者的选择不够普遍,后者听起来太过于苛刻.
另一个选择是做一些像这个问题中的答案,从RNN获取所有变量.但是,我不确定如何以标准方式将隐藏状态与其他变量分开.
在使用库提供的RNN API时,是否有一种很好的方法可以从RNN获取所有隐藏状态?
我已经在这里创建了一个 PR ,它可能会帮助您处理简单的情况
让我简要解释一下我的实现,以便您可以根据需要编写自己的版本。主要部分是函数的修改_time_step:
def _time_step(time, output_ta_t, state, *args):
Run Code Online (Sandbox Code Playgroud)
*args除了传入额外参数之外,参数保持不变。但是为什么呢args?那是因为我想支持tensorflow的习惯行为。您只需忽略args参数即可返回最终状态:
if states_ta is not None:
# If you want to return all states, set `args` to be `states_ta`
loop_vars = (time, output_ta, state, states_ta)
else:
# If you want the final state only, ignore `args`
loop_vars = (time, output_ta, state)
Run Code Online (Sandbox Code Playgroud)
如何利用它?
if args:
args = tuple(
ta.write(time, out) for ta, out in zip(args[0], [new_state])
)
Run Code Online (Sandbox Code Playgroud)
事实上,这只是以下(原始)代码的修改:
output_ta_t = tuple(
ta.write(time, out) for ta, out in zip(output_ta_t, output)
)
Run Code Online (Sandbox Code Playgroud)
现在args应该包含您想要的所有状态。
完成上述所有工作后,您可以使用以下代码获取状态(或最终状态):
_, output_final_ta, *state_info = control_flow_ops.while_loop( ...
Run Code Online (Sandbox Code Playgroud)
和
if states_ta is not None:
final_state, states_final_ta = state_info
else:
final_state, states_final_ta = state_info[0], None
Run Code Online (Sandbox Code Playgroud)
虽然我还没有在复杂的情况下测试它,但它应该在“简单”条件下工作(这是我的测试用例)
| 归档时间: |
|
| 查看次数: |
1882 次 |
| 最近记录: |