我正在尝试更新tf.Variable内部a tf.while_loop(),使用tf.scatter_update().但是,结果是初始值而不是更新的值.以下是我要做的示例代码:
from __future__ import print_function
import tensorflow as tf
def cond(sequence_len, step):
return tf.less(step,sequence_len)
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin,1,step,use_locking=None)
tf.get_variable_scope().reuse_variables()
return (sequence_len, step+1)
with tf.Graph().as_default():
sess = tf.Session()
step = tf.constant(0)
sequence_len = tf.constant(10)
_,step, = tf.while_loop(cond,
body,
[sequence_len, step],
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None)
begin = tf.get_variable("begin",[3],dtype=tf.int32)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([begin,step]))
Run Code Online (Sandbox Code Playgroud)
结果是:[array([0, 0, 0], dtype=int32), 10].但是,我认为结果应该是[0, 0, 10].我在这里做错了吗?
这里的问题是循环体中没有任何东西取决于你的tf.scatter_update()操作,所以它永远不会被执行.使其工作的最简单方法是在返回值的更新上添加控件依赖项:
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin, 1, step, use_locking=None)
tf.get_variable_scope().reuse_variables()
with tf.control_dependencies([begin]):
return (sequence_len, step+1)
Run Code Online (Sandbox Code Playgroud)
请注意,此问题并非TensorFlow中的循环所特有.如果您刚刚定义了一个tf.scatter_update()被begin调用的操作但是调用sess.run()它,或者依赖于它的操作,那么更新将不会发生.当你使用时,tf.while_loop()无法直接运行循环体中定义的操作,因此获得副作用的最简单方法是添加控件依赖项.
请注意,最终结果是[0, 9, 0]:每次迭代都会将当前步骤分配给begin[1],并且在最后一次迭代中,当前步骤的值为9(条件为false时step == 10).
| 归档时间: |
|
| 查看次数: |
1484 次 |
| 最近记录: |