你如何使用TensorFlow Graphkeys来获得所有权重?

SSt*_*ors 4 python machine-learning tensorflow

Tensorflow定义了预设的集合,如下所示:https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections

我目前正在使用tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)所有变量[*这些变量命名; 如果他们不是那么他们就不会被展示,即使他们存在].

同样,我希望tf.get_collection(tf.GraphKeys.WEIGHTS)输出一个权重列表,但它是一个空数组.这也适用于GraphKeys.BIASES.ACTIVATIONS.

这里发生了什么?

在我看来,这里似乎有两种可能性.首先,它们实际上从未自动定义,只是推荐的集合名称.其次,我的网络非常破碎,但似乎并非如此.

有人有这方面的经验吗?

Ler*_*ang 5

默认情况下,所有变量都绑定到tf.GraphKeys.GLOBAL_VARIABLES集合.方便的方法是将每个权重设置为集合tf.GraphKeys.WEIGHTS,如下所示:

In [2]: w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
In [3]: w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)

然后你可以通过以下方式获取它们:

tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
Run Code Online (Sandbox Code Playgroud)

这是权重:

[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,
 <tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
Run Code Online (Sandbox Code Playgroud)