访问RNN权重 - Tensorflow

Rag*_*lli 6 python lstm tensorflow recurrent-neural-network

我使用的是tf.python.ops.rnn_cell.GRUCell

output, state = tf.nn.dynamic_rnn(
        GRUCell(HID_DIM),
        sequence,
        dtype=tf.float32,
        sequence_length=length(sequence)
)
Run Code Online (Sandbox Code Playgroud)

如何获得此GRUCell的权重.我需要看看它们进行调试.

Fil*_*ixo 2

可以使用以下命令打印当前会话中所有变量的值:

with tf.Session() as sess:
    variables_names =[v.name for v in tf.trainable_variables()]
    values = sess.run(variables_names)
    for k,v in zip(variables_names, values):
        print(k, v)
Run Code Online (Sandbox Code Playgroud)