使用TensorFlow Python API时,我创建了一个变量(没有name在构造函数中指定它),并且它的name属性具有值"Variable_23:0".当我尝试使用这个变量时tf.get_variable("Variable23"),"Variable_23_1:0"会创建一个名为的新变量.如何正确选择"Variable_23"而不是创建新的?
我想要做的是按名称选择变量,并重新初始化它,以便我可以微调权重.
Min*_*ark 36
该get_variable()函数创建一个新变量或返回之前创建的变量get_variable().它不会返回使用创建的变量tf.Variable().这是一个简单的例子:
>>> with tf.variable_scope("foo"):
... bar1 = tf.get_variable("bar", (2,3)) # create
...
>>> with tf.variable_scope("foo", reuse=True):
... bar2 = tf.get_variable("bar") # reuse
...
>>> with tf.variable_scope("", reuse=True): # root variable scope
... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
...
>>> (bar1 is bar2) and (bar2 is bar3)
True
Run Code Online (Sandbox Code Playgroud)
如果您没有使用创建变量tf.get_variable(),则有几个选项.首先,您可以使用tf.global_variables()(如@mrry建议的那样):
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True
Run Code Online (Sandbox Code Playgroud)
或者您可以这样使用tf.get_collection():
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
Run Code Online (Sandbox Code Playgroud)
编辑
您还可以使用get_tensor_by_name():
>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal
bar2 in value.
Run Code Online (Sandbox Code Playgroud)
回想一下张量是一个操作的输出.它与操作同名,另外:0.如果操作有多个输出,它们具有相同的名称作为操作加:0,:1,:2,等等.
mrr*_*rry 35
通过名称获取变量的最简单方法是在tf.global_variables()集合中搜索它:
var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
Run Code Online (Sandbox Code Playgroud)
这适用于现有变量的临时重用.当您想要在模型的多个部分之间共享变量时,更加结构化的方法将在" 共享变量"教程中介绍.
| 归档时间: |
|
| 查看次数: |
40633 次 |
| 最近记录: |