Tensorflow定义了此API:
tf.local_variables()
返回使用的所有变量
collection=[LOCAL_VARIABLES]
.返回:
本地Variable对象的列表.
TensorFlow中的局部变量究竟是什么?有人能举个例子吗?
Sal*_*ali 24
简短回答:TF中的局部变量是用它创建的任何变量collections=[tf.GraphKeys.LOCAL_VARIABLES]
.例如:
e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])
Run Code Online (Sandbox Code Playgroud)
LOCAL_VARIABLES:每台机器本地的Variable对象的子集.通常用于临时变量,如计数器.注意:使用tf.contrib.framework.local_variable添加到此集合.
它们通常不会保存/恢复到检查点,并用于临时值或中间值.
答案很长:这对我来说也是一个混乱的根源.在开始时我认为局部变量在几乎所有编程语言中都与局部变量相同,但它不是一回事:
import tensorflow as tf
def some_func():
z = tf.Variable(1, name='var_z')
a = tf.Variable(1, name='var_a')
b = tf.get_variable('var_b', 2)
with tf.name_scope('aaa'):
c = tf.Variable(3, name='var_c')
with tf.variable_scope('bbb'):
d = tf.Variable(3, name='var_d')
some_func()
some_func()
print [str(i.name) for i in tf.global_variables()]
print [str(i.name) for i in tf.local_variables()]
Run Code Online (Sandbox Code Playgroud)
无论我尝试什么,我总是只收到全局变量:
['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0']
[]
Run Code Online (Sandbox Code Playgroud)
文档tf.local_variables
没有提供很多细节:
局部变量 - 每个过程变量,通常不保存/恢复到检查点并用于临时值或中间值.例如,它们可以用作度量计算的计数器或本机读取数据的时期数.local_variable()自动将新变量添加到GraphKeys.LOCAL_VARIABLES.此便捷函数返回该集合的内容.
但是在tf.Variable类中读取init方法的文档时,我发现在创建变量时,您可以通过分配列表来提供您想要的变量类型collections
.
可能的集合元素列表在这里.所以要创建一个局部变量,你需要做这样的事情.您将在以下列表中看到它local_variables
:
e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])
print [str(i.name) for i in tf.local_variables()]
Run Code Online (Sandbox Code Playgroud)
Yar*_*tov 18
它与常规变量相同,但它与default(GraphKeys.VARIABLES
)不同.该保护程序使用该集合初始化要保存的默认变量列表,因此具有local
指定具有默认情况下不保存该变量的效果.
我只看到一个在代码库中使用它的地方,即 limit_epochs
with ops.name_scope(name, "limit_epochs", [tensor]) as name:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
epochs = variables.Variable(
zero64, name="epochs", trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
Run Code Online (Sandbox Code Playgroud)
我认为,这里需要了解TensorFlow集合。
TensorFlow提供了集合,这些集合是张量或其他对象的命名列表,例如 tf.Variable
实例)的。
以下是内置集合:
tf.GraphKeys.GLOBAL_VARIABLES #=> 'variables'
tf.GraphKeys.LOCAL_VARIABLES #=> 'local_variables'
tf.GraphKeys.MODEL_VARIABLES #=> 'model_variables'
tf.GraphKeys.TRAINABLE_VARIABLES #=> 'trainable_variables'
Run Code Online (Sandbox Code Playgroud)
通常,在创建变量时,可以通过将其作为传递给collections
参数的集合之一来显式传递该集合,从而将其添加到给定的集合中。
从理论上讲,变量可以是内置或自定义集合的任意组合。但是,内置集合用于特定目的:
tf.GraphKeys.GLOBAL_VARIABLES
:
Variable()
构造函数或get_variable()
自动添加新的变数图表收集GraphKeys.GLOBAL_VARIABLES
,除非collections
参数显式传递,不包括GLOBAL_VARIABLE
。 tf.global_variables()
以获取更多详细信息。 tf.GraphKeys.TRAINABLE_VARIABLES
:
trainable=True
(这是默认行为),Variable()
构造函数将get_variable()
自动向该图集合添加新变量。但是,当然,您可以使用collections
参数将变量添加到任何所需的集合中。 tf.trainable_variables()
以获取更多详细信息。 tf.GraphKeys.LOCAL_VARIABLES
:
tf.contrib.framework.local_variable()
添加到此收藏集。但是,当然,您可以使用collections
参数将变量添加到任何所需的集合中。 tf.local_variables()
以获取更多详细信息。 tf.GraphKeys.MODEL_VARIABLES
:
tf.contrib.framework.model_variable()
添加到此收藏集。但是,当然,您可以使用collections
参数将变量添加到任何所需的集合中。 tf.model_variables()
以获取更多详细信息。 您也可以使用自己的收藏集。任何字符串都是有效的集合名称,无需显式创建集合。要在创建变量后将变量(或任何其他对象)添加到集合中,请调用tf.add_to_collection()
。
例如,
tf.__version__ #=> '1.9.0'
# initializing using a Tensor
my_variable01 = tf.get_variable("var01", dtype=tf.int32, initializer=tf.constant([23, 42]))
# initializing using a convenient initializer
my_variable02 = tf.get_variable("var02", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.zeros_initializer)
my_variable03 = tf.get_variable("var03", dtype=tf.int32, initializer=tf.constant([1, 2]), trainable=None)
my_variable04 = tf.get_variable("var04", dtype=tf.int32, initializer=tf.constant([3, 4]), trainable=False)
my_variable05 = tf.get_variable("var05", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.ones_initializer, trainable=True)
my_variable06 = tf.get_variable("var06", dtype=tf.int32, initializer=tf.constant([5, 6]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=None)
my_variable07 = tf.get_variable("var07", dtype=tf.int32, initializer=tf.constant([7, 8]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=True)
my_variable08 = tf.get_variable("var08", dtype=tf.int32, initializer=tf.constant(1), collections=[tf.GraphKeys.MODEL_VARIABLES], trainable=None)
my_variable09 = tf.get_variable("var09", dtype=tf.int32, initializer=tf.constant(2), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES, "my_collectio
n"])
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)
[var.name for var in tf.global_variables()] #=> ['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']
[var.name for var in tf.local_variables()] #=> ['var06:0', 'var07:0', 'var09:0']
[var.name for var in tf.trainable_variables()] #=> ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
[var.name for var in tf.model_variables()] #=> ['var08:0', 'var09:0']
[var.name for var in tf.get_collection("trainable_variables")] #=> ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
[var.name for var in tf.get_collection("my_collection")] #=> ['var09:0', 'var10:0']
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
9992 次 |
最近记录: |