TensorFlow throws an error only when using MultiRNNCell

tao*_*oat 0 python deep-learning tensorflow recurrent-neural-network

I am building an encoder-decoder model in TensorFlow 1.0.1 using the legacy sequence-to-sequence framework. Everything works fine when I have one layer of LSTMs in the encoder and decoder. However, when I try to use >1 layer of LSTMs wrapped in a MultiRNNCell, I get an error when calling tf.contrib.legacy_seq2seq.rnn_decoder.

The full error is at the end of this post, but in brief, it is caused by the line

(c_prev, m_prev) = state

in TensorFlow, which throws TypeError: 'Tensor' object is not iterable.. I am confused by this, because the initial state I pass to rnn_decoder is indeed a tuple, as it should be. As far as I can tell, the only difference between using 1 or >1 layers is that the latter involves using MultiRNNCell. Are there some API quirks I should know about when using it?

Here is my code (based on an example in a GitHub repo). Apologies for the length; this is as minimal as I could make it while still being complete and verifiable.

import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq
import tensorflow.contrib.rnn as rnn

seq_len = 50
input_dim = 300
output_dim = 12
num_layers = 2
hidden_units = 100

sess = tf.Session()

encoder_inputs = []
decoder_inputs = []

for i in range(seq_len):
    encoder_inputs.append(tf.placeholder(tf.float32, shape=(None, input_dim),
                                         name="encoder_{0}".format(i)))

for i in range(seq_len + 1):
    decoder_inputs.append(tf.placeholder(tf.float32, shape=(None, output_dim),
                                         name="decoder_{0}".format(i)))

if num_layers > 1:
    # Encoder cells (bidirectional)
    # Forward
    enc_cells_fw = [rnn.LSTMCell(hidden_units)
                    for _ in range(num_layers)]
    enc_cell_fw = rnn.MultiRNNCell(enc_cells_fw)
    # Backward
    enc_cells_bw = [rnn.LSTMCell(hidden_units)
                    for _ in range(num_layers)]
    enc_cell_bw = rnn.MultiRNNCell(enc_cells_bw)
    # Decoder cell
    dec_cells = [rnn.LSTMCell(2*hidden_units)
                 for _ in range(num_layers)]
    dec_cell = rnn.MultiRNNCell(dec_cells)
else:
    # Encoder
    enc_cell_fw = rnn.LSTMCell(hidden_units)
    enc_cell_bw = rnn.LSTMCell(hidden_units)
    # Decoder
    dec_cell = rnn.LSTMCell(2*hidden_units)

# Make sure input and output are the correct dimensions
enc_cell_fw = rnn.InputProjectionWrapper(enc_cell_fw, input_dim)
enc_cell_bw = rnn.InputProjectionWrapper(enc_cell_bw, input_dim)
dec_cell = rnn.OutputProjectionWrapper(dec_cell, output_dim)

_, final_fw_state, final_bw_state = \
     rnn.static_bidirectional_rnn(enc_cell_fw,
                                  enc_cell_bw,
                                  encoder_inputs,
                                  dtype=tf.float32)

# Concatenate forward and backward cell states
# (The state is a tuple of previous output and cell state)
if num_layers == 1:
    initial_dec_state = tuple([tf.concat([final_fw_state[i],
                                          final_bw_state[i]], 1) 
                               for i in range(2)])
else:
    initial_dec_state = tuple([tf.concat([final_fw_state[-1][i],
                                          final_bw_state[-1][i]], 1) 
                               for i in range(2)])

decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell)

tf.global_variables_initializer().run(session=sess)

Here is the error:

Traceback (most recent call last):
  File "example.py", line 67, in <module>
    decoder = seq2seq.rnn_decoder(decoder_inputs, initial_dec_state, dec_cell)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py", line 150, in rnn_decoder
    output, state = cell(inp, state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 426, in __call__
    output, res_state = self._cell(inputs, state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 655, in __call__
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 321, in __call__
    (c_prev, m_prev) = state
  File "/home/tao/.virtualenvs/example/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 502, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

Thanks!

小智 6

The problem is the format of the initial state (initial_dec_state) that you pass to seq2seq.rnn_decoder.

When you use rnn.MultiRNNCell, you are building a multilayer recurrent network, so you need to provide an initial state for each of those layers.

Therefore, you should provide a list of tuples as the initial state, where each element of the list is the previous state coming from the corresponding layer of the recurrent network.

So your initial_dec_state, which is initialized like this:

    initial_dec_state = tuple([tf.concat([final_fw_state[-1][i],
                                          final_bw_state[-1][i]], 1)
                               for i in range(2)])

should instead be something like this:

    initial_dec_state = [
        tuple([tf.concat([final_fw_state[j][i], final_bw_state[j][i]], 1)
               for i in range(2)])
        for j in range(len(final_fw_state))
    ]

which creates a list of tuples in the following format:

    [(state_c1, state_m1), (state_c2, state_m2) ...]

In more detail, the 'Tensor' object is not iterable. error happens because seq2seq.rnn_decoder internally calls your rnn.MultiRNNCell (dec_cell), passing the initial state (initial_dec_state) to it.

rnn.MultiRNNCell.__call__ iterates over the list of initial states and, for each of them, unpacks the tuple (c_prev, m_prev) (in the statement (c_prev, m_prev) = state).

So if you pass only a single tuple, rnn.MultiRNNCell.__call__ will iterate over it, and as soon as it reaches (c_prev, m_prev) = state it will find a tensor (which should instead be a tuple) as state, and will throw the 'Tensor' object is not iterable. error.
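
As a minimal, standalone sketch of that failure mode (this is not the poster's code; the cell sizes and placeholder names are made up for illustration, and it assumes the TensorFlow 1.0 contrib API used above), calling a two-layer MultiRNNCell with a single (c, m) tuple reproduces the same exception:

import tensorflow as tf
import tensorflow.contrib.rnn as rnn

# Two stacked LSTM layers, analogous to the decoder above (sizes are arbitrary).
cell = rnn.MultiRNNCell([rnn.LSTMCell(4) for _ in range(2)])

x = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.placeholder(tf.float32, shape=(None, 4))
m = tf.placeholder(tf.float32, shape=(None, 4))

# Wrong: a single (c, m) tuple for a 2-layer cell. MultiRNNCell indexes the
# state per layer, so layer 0 receives the bare tensor c as its whole state,
# and "(c_prev, m_prev) = state" raises "'Tensor' object is not iterable."
output, state = cell(x, (c, m))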

A good way to see what initial state format seq2seq.rnn_decoder expects is to call dec_cell.zero_state(batch_size, dtype=tf.float32). This method returns zero-filled state tensors in the exact format required to initialize the recurrent module you are using.
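
For example (a minimal sketch using the dec_cell defined in the question and a hypothetical batch_size; this snippet is not part of the original answer):

# Hypothetical batch size, used only to build the zero state for inspection.
batch_size = 32
zero_state = dec_cell.zero_state(batch_size, dtype=tf.float32)

# For the 2-layer decoder above this is a tuple of per-layer
# LSTMStateTuple(c=..., h=...) entries, i.e. the
# [(state_c1, state_m1), (state_c2, state_m2), ...] structure described above.
print(zero_state)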