如何在没有固定batch_size的情况下设置Tensorflow dynamic_rnn,zero_state?

Dav*_*vid 7 python tensorflow

根据Tensorflow的官方网站(https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state),zero_state必须指定batch_size.我发现的许多例子使用此代码:

    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)
Run Code Online (Sandbox Code Playgroud)

对于培训步骤,可以修复批量大小.但是,在预测时,测试集的形状可能与训练集的批量大小不同.例如,我的一批训练数据具有形状[100,255,128].批量大小为100,包含255个步骤和128个输入.而测试集是[2000,255,128].我无法预测,因为在dynamic_rnn(initial_state)中,它已经设置了一个固定的batch_size = 100.如何修复此问题?

谢谢.

VS_*_*_FF 11

您可以将其指定batch_size为占位符,而不是常量.只需确保输入相关的数字feed_dict,这对于培训和测试都是不同的

重要的是,指定[]占位符的维度,因为如果指定None,可能会出现错误,这在其他地方也是如此.所以像这样的东西应该工作:

batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)
# rest of your code
out = sess.run(outputs, feed_dict={batch_size:100})
out = sess.run(outputs, feed_dict={batch_size:10})
Run Code Online (Sandbox Code Playgroud)

显然,请确保批处理参数与输入的形状相匹配,这dynamic_rnn将解释为[batch_size, seq_len, features][seq_len, batch_size, features]如果time_major设置为True