Tensorflow:tf.get_collection不返回范围中的变量

Mat*_*per 6 python tensorflow

我试图让所有的变量在变量范围,如解释在这里.但是,tf.get_collection(tf.GraphKeys.VARIABLES, scope='my_scope')即使该范围中存在变量,该行也会返回一个空列表.

这是一些示例代码:

import tensorflow as tf

with tf.variable_scope('my_scope'):
    a = tf.Variable(0)
print tf.get_collection(tf.GraphKeys.VARIABLES, scope='my_scope')
Run Code Online (Sandbox Code Playgroud)

打印[].

如何获取声明的变量'my_scope'

mrr*_*rry 11

tf.GraphKeys.VARIABLES自TensorFlow 0.12起,该集合名称已被弃用.使用tf.GraphKeys.GLOBAL_VARIABLES将给出预期的结果:

with tf.variable_scope('my_scope'):
    a = tf.Variable(0)
print tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope')
# ==> '[<tensorflow.python.ops.variables.Variable object at 0x7f33f67ebbd0>]'
Run Code Online (Sandbox Code Playgroud)

  • @CharlieParker tf.get_collection()方法假定所有内容都在单个默认图形中(即tf.get_default_graph()返回的图形)。如果您有多个图形对象,并且感兴趣的图形称为`g`,则可以使用[`g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope ='my_scope')`](https://www.tensorflow。 org / api_docs / python / tf / Graph#get_collection)。但是,如果这不起作用,则可能是错误的`tf.Graph`对象。用更多的代码来解决一个单独的问题可能是最容易的。 (2认同)