Wei*_*Liu 34 tensorflow pre-trained-model
我在范围内创建了一个可训练的变量.后来,我进入相同的范围,设置范围reuse_variables,并用于get_variable检索相同的变量.但是,我无法将变量的可训练属性设置为False.我的get_variable界限如下:
weight_var = tf.get_variable('weights', trainable = False)
Run Code Online (Sandbox Code Playgroud)
但变量'weights'仍在输出中tf.trainable_variables.
我可以使用?设置共享变量的trainable标志吗?Falseget_variable
我想这样做的原因是我试图在我的模型中重用从VGG网络预训练的低级过滤器,我想像以前一样构建图形,检索权重变量,并分配VGG过滤器值到重量变量,然后在下面的训练步骤中保持它们固定.
Oli*_*rot 28
在查看文档和代码之后,我无法找到从中删除变量的方法TRAINABLE_VARIABLES.
tf.get_variable('weights', trainable=True)调用时,变量被添加到列表中TRAINABLE_VARIABLES.tf.get_variable('weights', trainable=False),你得到相同的变量,但是参数trainable=False没有效果,因为变量已经存在于列表中TRAINABLE_VARIABLES(并且无法从那里删除它)在调用minimize优化器的方法时(参见文档),您可以var_list=[...]使用要优化的变量传递as参数.
例如,如果要冻结除最后两个之外的所有VGG图层,则可以传递最后两个图层的权重var_list.
您可以使用a tf.train.Saver()保存变量并在以后恢复它们(请参阅本教程).
saver.save(sess, "/path/to/dir/model.ckpt").saver.restore(sess, "/path/to/dir/model.ckpt").(可选)您可以决定仅在检查点文件中保存一些变量.有关详细信息,请参阅doc.
roc*_*yne 10
当您只想训练或优化预训练网络的某些层时,您需要了解这一点.
TensorFlow的minimize方法采用可选参数var_list,即通过反向传播调整的变量列表.
如果未指定var_list,则优化程序可以调整图形中的任何TF变量.当您指定一些变量时var_list,TF会保持所有其他变量不变.
这是jonbruner和他的合作者使用的脚本示例.
tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)
Run Code Online (Sandbox Code Playgroud)
这将找到之前定义的变量名称中包含"g_"的所有变量,将它们放入列表中,并对它们运行ADAM优化器.
您可以在Quora上找到相关答案
为了从可训练变量列表中删除变量,您可以首先通过以下方式访问该集合:
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
其中trainable_collection包含对可训练变量集合的引用.如果从该列表中弹出元素,例如trainable_collection.pop(0),您将从可训练变量中删除相应的变量,因此不会训练此变量.
虽然这有效pop,但我仍然在努力寻找正确使用remove正确参数的方法,因此我们不依赖于变量的索引.
编辑:鉴于您在图表中有变量的名称(您可以通过检查图形protobuf或使用Tensorboard更容易获得),您可以使用它来遍历可训练变量列表然后删除来自可训练集合的变量.示例:假设我希望对名称"batch_normalization/gamma:0"和"batch_normalization/beta:0" NOT的变量进行训练,但它们已添加到TRAINABLE_VARIABLES集合中.我能做的是:`
#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
#uses the attribute 'name' of the variable
if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
variables_to_remove.append(vari)
for rem in variables_to_remove:
trainable_collection.remove(rem)
Run Code Online (Sandbox Code Playgroud)
`这将成功从集合中删除这两个变量,它们将不再受到训练.
| 归档时间: |
|
| 查看次数: |
22967 次 |
| 最近记录: |