如何在TensorFlow中使用tf.get_variable和numpy值初始化变量?

Pin*_*hio 22 python numpy tensorflow

我想用numpy值初始化我网络上的一些变量.为了这个例子考虑:

init=np.random.rand(1,2)
tf.get_variable('var_name',initializer=init)
Run Code Online (Sandbox Code Playgroud)

当我这样做时,我收到一个错误:

ValueError: Shape of a new variable (var_name) must be fully defined, but instead was <unknown>.
Run Code Online (Sandbox Code Playgroud)

为什么我得到那个错误?

为了尝试修复它,我尝试了:

tf.get_variable('var_name',initializer=init, shape=[1,2])
Run Code Online (Sandbox Code Playgroud)

这产生了一个更奇怪的错误:

TypeError: 'numpy.ndarray' object is not callable
Run Code Online (Sandbox Code Playgroud)

我尝试阅读文档和示例,但它并没有真正帮助.

是否无法使用TensorFlow中的get_variable方法使用numpy数组初始化变量?

kev*_*man 37

以下作品:

init = tf.constant(np.random.rand(1, 2))
tf.get_variable('var_name', initializer=init)
Run Code Online (Sandbox Code Playgroud)

文档get_variable确实有点缺乏.仅供参考,initializer参数必须是TensorFlow Tensor对象(可以通过调用您的案例中tf.constantnumpy值来构造),或者是带有两个参数的"可调用"对象,shape以及dtype值的形状和数据类型它应该回归.同样,在您的情况下,您可以编写以下内容,以防您想使用"可调用"机制:

init = lambda shape, dtype: np.random.rand(*shape)
tf.tf.get_variable('var_name', initializer=init, shape=[1, 2])
Run Code Online (Sandbox Code Playgroud)

  • [This](http://stackoverflow.com/questions/111234/what-is-a-callable-in-python)是您问题的绝佳答案. (2认同)

Nez*_*zha 10

@keveman回答得很好,并且作为补充,有使用 tf.get_variable('var_name',initializer = init),tensorflow文档确实提供了一个全面的例子.

import numpy as np
import tensorflow as tf

value = [0, 1, 2, 3, 4, 5, 6, 7]
# value = np.array(value)
# value = value.reshape([2, 4])
init = tf.constant_initializer(value)

print('fitting shape:')
tf.reset_default_graph()
with tf.Session() :
    x = tf.get_variable('x', shape = [2, 4], initializer = init)
    x.initializer.run()
    print(x.eval())

    fitting shape :
[[0.  1.  2.  3.]
[4.  5.  6.  7.]]

print('larger shape:')
tf.reset_default_graph()
with tf.Session() :
    x = tf.get_variable('x', shape = [3, 4], initializer = init)
    x.initializer.run()
    print(x.eval())

    larger shape :
[[0.  1.  2.  3.]
[4.  5.  6.  7.]
[7.  7.  7.  7.]]

print('smaller shape:')
tf.reset_default_graph()
with tf.Session() :
    x = tf.get_variable('x', shape = [2, 3], initializer = init)

    * <b>`ValueError`< / b > : Too many elements provided.Needed at most 6, but received 8
Run Code Online (Sandbox Code Playgroud)

https://www.tensorflow.org/api_docs/python/tf/constant_initializer