我正在尝试熟悉 TensorFlow,但我不确定占位符、变量等。为了让事情变得简单,我尝试创建一个非常简单的计算 - 一个占位符和一个变量,该变量只是占位符乘以 2。
我把所有东西都放在一个函数中,像这样:
import tensorflow as tf
def try_variable(value):
x = tf.placeholder(tf.float64, name='x')
v = tf.Variable(x * 2, name='v', validate_shape=False)
with tf.Session() as session:
init = tf.global_variables_initializer()
session.run(init, feed_dict={x: value})
return session.run(v)
Run Code Online (Sandbox Code Playgroud)
然后我调用函数:
print(try_variable(80))
Run Code Online (Sandbox Code Playgroud)
确实输出是160。
但是当我再次调用它时:
print(try_variable(80))
Run Code Online (Sandbox Code Playgroud)
我收到一个错误:
InvalidArgumentError:您必须使用 dtype double 为占位符张量“x”提供一个值
我错过了什么?
现在,您每次调用该函数时都会创建一个新变量和占位符,因此第二次调用该try_variable函数时,您实际上拥有 2 个占位符和 2 个 TensorFlow 变量!x, x_1, v, v_1.
因此,在您第二次运行 init 操作时,您仅为占位符提供初始值,该占位符x_1现在绑定到 python 变量x。
如果要打印当前图中所有张量的名称,可以调用
print [n.name for n in tf.get_default_graph().as_graph_def().node]
Run Code Online (Sandbox Code Playgroud)
如果您每次调用该函数时仍想创建 2 个新张量,一种选择是每次调用该函数时使用该命令重置默认图形 tf.reset_default_graph()
- 极不推荐这样做。
| 归档时间: |
|
| 查看次数: |
308 次 |
| 最近记录: |