Tensorflow:tf.get_variable如何工作?

fir*_*hin 16 tensorflow

我已经从这个问题中读到了关于tf.get_variable的内容,还有一些来自tensorflow网站上的文档.但是,我仍然不清楚,无法在网上找到答案.

tf.get_variable如何工作?例如:

var1 = tf.Variable(3.,dtype=float64)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
Run Code Online (Sandbox Code Playgroud)

这是否意味着var2另一个初始化类似于var1的变量?或者var2var1的别名(我尝试过它似乎没有)?

var1var2如何相关?

当我们得到的变量不存在时,如何构造变量?

nes*_*uno 28

tf.get_variable(name)在张量流图中创建一个名为name(或添加_,如果name已存在于当前范围中)的新变量.

在您的示例中,您将创建一个名为的python变量var1.

**Tensorflow图中该变量的名称不是**var1,而是Variable:0.

您定义的每个节点都有自己可以指定的名称,或者让tensorflow给出一个默认(并且始终不同)的名称.您可以看到name访问namepython变量属性的值.(即print(var1.name)).

在第二行,您将定义一个Python变量, var2其张流图中的名称为var1.

剧本

import tensorflow as tf

var1 = tf.Variable(3.,dtype=tf.float64)
print(var1.name)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
print(var2.name)
Run Code Online (Sandbox Code Playgroud)

事实上印刷品:

Variable:0
var1:0
Run Code Online (Sandbox Code Playgroud)

相反,如果你想要定义一个在张量var1流图中调用的变量(节点),然后获得对该节点的引用,你不能简单地使用tf.get_variable("var1")它,因为它将创建一个新的不同变量valled var1_1.

这个脚本

var1 = tf.Variable(3.,dtype=tf.float64, name="var1")
print(var1.name)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
print(var2.name)
Run Code Online (Sandbox Code Playgroud)

打印:

var1:0
var1_1:0
Run Code Online (Sandbox Code Playgroud)

如果要创建对节点的引用var1,首先要:

  1. 不得不更换tf.Variabletf.get_variable.创建的变量tf.Variable不能共享,而后者可以.

  2. 知道了什么scopevar1是并允许reuse声明引用时范围.

查看代码是更好的理解方式

import tensorflow as tf

#var1 = tf.Variable(3.,dtype=tf.float64, name="var1")
var1 = tf.get_variable(initializer=tf.constant_initializer(3.), dtype=tf.float64, name="var1", shape=())
current_scope = tf.contrib.framework.get_name_scope()
print(var1.name)
with tf.variable_scope(current_scope, reuse=True):
    var2 = tf.get_variable("var1",[],dtype=tf.float64)
    print(var2.name)
Run Code Online (Sandbox Code Playgroud)

输出:

var1:0
var1:0
Run Code Online (Sandbox Code Playgroud)