Dan*_*Fox 6 python neural-network lstm tensorflow
嗨我正在使用以下函数为lstm rnn单元格.
def LSTM_RNN(_X, _istate, _weights, _biases):
# Function returns a tensorflow LSTM (RNN) artificial neural network from given parameters.
# Note, some code of this notebook is inspired from an slightly different
# RNN architecture used on another dataset:
# https://tensorhub.com/aymericdamien/tensorflow-rnn
# (NOTE: This step could be greatly optimised by shaping the dataset once
# input shape: (batch_size, n_steps, n_input)
_X = tf.transpose(_X, [1, 0, 2]) # permute n_steps and batch_size
# Reshape to prepare input to hidden activation
_X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)
# Linear activation
_X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']
# Define a lstm cell with tensorflow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Split data because rnn cell needs a list of inputs for the RNN inner loop
_X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)
# Get lstm cell output
outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)
# Linear activation
# Get inner loop last output
return tf.matmul(outputs[-1], _weights['out']) + _biases['out']
Run Code Online (Sandbox Code Playgroud)
函数的输出存储在pred变量下.
pred = LSTM_RNN(x, istate, weights, biases)
但它显示以下错误.(表明张量对象不可迭代.)
这是ERROR图像链接 - http://imgur.com/a/NhSFK
请帮助我,如果这个问题看起来很愚蠢,我很抱歉,因为我对lstm和tensor流程库很新.
谢谢.
当它试图state用语句解压时发生错误c, h=state.根据其tensorflow你正在使用的版本(你可以通过键入检查版本信息import tensorflow; tensorflow.__version__r0.11之前版本在Python解释器),默认设置的state_is_tuple参数,当你初始化rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)设置为False.请参阅此处的文档.
由于tensorflow版本r0.11(或主版),默认设置为state_is_tuple被设定为True.请参阅此处的文档.
如果安装了r0.11或tensorflow的主版本,请尝试将BasicLSTMCell初始化行更改为:
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=False).您遇到的错误应该消失.虽然,他们的页面确实表示该state_is_tuple=False行为很快就会被弃用.