bge*_*ge0 17 python tensorflow
我有一个设置,我需要在主要初始化后使用初始化LSTM tf.initialize_all_variables().即我想打电话tf.initialize_variables([var_list])
有没有办法收集所有内部可训练变量:
这样我可以初始化JUST这些参数吗?
我想要这个的主要原因是因为我不想重新初始化一些训练过的值.
Raf*_*icz 17
解决问题的最简单方法是使用变量范围.范围内变量的名称将以其名称为前缀.这是一个简短的片段:
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
# Retrieve just the LSTM variables.
lstm_variables = [v for v in tf.all_variables()
if v.name.startswith(vs.name)]
# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)
Run Code Online (Sandbox Code Playgroud)
它会以同样的方式工作MultiRNNCell.
编辑:改变tf.trainable_variables以tf.all_variables()
小智 11
您还可以使用tf.get_collection():
cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
# Execute the LSTM cell here in any way, for example:
for i in range(num_steps):
output[i], state = cell(input_data[i], state)
lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)
Run Code Online (Sandbox Code Playgroud)
(部分复制自拉法尔的回答)
请注意,最后一行等同于Rafal代码中的列表推导.
基本上,tensorflow存储的变量,它可以用中获取一个全球征集tf.all_variables()或tf.get_collection(tf.GraphKeys.VARIABLES).如果scope在tf.get_collection()函数中指定(范围名称),则只能在范围在指定范围内的集合中获取张量(在本例中为变量).
编辑:您也可以tf.GraphKeys.TRAINABLE_VARIABLES用来获取可训练的变量.但由于vanilla BasicLSTMCell没有初始化任何不可训练的变量,因此两者在功能上都是等价的.有关默认图表集合的完整列表,请查看此项.
| 归档时间: |
|
| 查看次数: |
11306 次 |
| 最近记录: |