Tensorflow:如何从rnn_cell.BasicLSTM&rnn_cell.MultiRNNCell获取所有变量

bge*_*ge0 17 python tensorflow

我有一个设置,我需要在主要初始化后使用初始化LSTM tf.initialize_all_variables().即我想打电话tf.initialize_variables([var_list])

有没有办法收集所有内部可训练变量:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

这样我可以初始化JUST这些参数吗?

我想要这个的主要原因是因为我不想重新初始化一些训练过的值.

Raf*_*icz 17

解决问题的最简单方法是使用变量范围.范围内变量的名称将以其名称为前缀.这是一个简短的片段:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)
Run Code Online (Sandbox Code Playgroud)

它会以同样的方式工作MultiRNNCell.

编辑:改变tf.trainable_variablestf.all_variables()


小智 11

您还可以使用tf.get_collection():

cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
Run Code Online (Sandbox Code Playgroud)

(部分复制自拉法尔的回答)

请注意,最后一行等同于Rafal代码中的列表推导.

基本上,tensorflow存储的变量,它可以用中获取一个全球征集tf.all_variables()tf.get_collection(tf.GraphKeys.VARIABLES).如果scopetf.get_collection()函数中指定(范围名称),则只能在范围在指定范围内的集合中获取张量(在本例中为变量).

编辑:您也可以tf.GraphKeys.TRAINABLE_VARIABLES用来获取可训练的变量.但由于vanilla BasicLSTMCell没有初始化任何不可训练的变量,因此两者在功能上都是等价的.有关默认图表集合的完整列表,请查看此项.