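I'm new to TensorFlow and am having trouble understanding the RNN module. I'm trying to extract the hidden/cell states from an LSTM. For my code, I'm using the implementation from https://github.com/aymericdamien/TensorFlow-Examples.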
import tensorflow as tf
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])
# Define weights
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}
def RNN(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
    # Permuting batch_size and n_steps
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    x = tf.reshape(x, …
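To make the shape comments above concrete, here is a rough sketch of the same transpose/reshape steps using NumPy as a stand-in for the TensorFlow ops. The sizes are hypothetical values I picked for illustration, and the final split is my assumption, inferred from the comment about needing a list of 'n_steps' tensors of shape (batch_size, n_input), since the code above is cut off before that point.

```python
import numpy as np

# Hypothetical sizes chosen only for illustration (not from the original post).
batch_size, n_steps, n_input = 4, 3, 2

# Dummy batch with shape (batch_size, n_steps, n_input).
x = np.arange(batch_size * n_steps * n_input, dtype=np.float32)
x = x.reshape(batch_size, n_steps, n_input)

# Swap the batch and time axes: (n_steps, batch_size, n_input).
x_t = np.transpose(x, (1, 0, 2))

# Flatten to (n_steps * batch_size, n_input).
x_flat = x_t.reshape(-1, n_input)

# Assumed final step: split into n_steps arrays of shape (batch_size, n_input).
steps = np.split(x_flat, n_steps, axis=0)

print(len(steps), steps[0].shape)  # prints: 3 (4, 2)
```

Each entry `steps[t]` then holds the inputs for time step `t` across the whole batch, i.e. the same values as `x[:, t, :]`.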