Tensorflow while循环:处理列表

web*_*ash 10 while-loop python-2.7 tensorflow

import tensorflow as tf

array = tf.Variable(tf.random_normal([10]))
i = tf.constant(0)
l = []

def cond(i,l):
   return i < 10

def body(i,l):
   temp = tf.gather(array,i)
   l.append(temp)
   return i+1,l

index,list_vals = tf.while_loop(cond, body, [i,l])
Run Code Online (Sandbox Code Playgroud)

我想以与上面代码中描述的类似方式处理张量数组.在while循环的主体中,我想逐个元素地处理数组以应用一些函数.为了演示,我给了一个小代码片段.但是,它给出了如下错误消息.

ValueError: Number of inputs and outputs of body must match loop_vars: 1, 2
Run Code Online (Sandbox Code Playgroud)

任何帮助解决这个问题表示赞赏.

谢谢

syg*_*ygi 12

引用文档:

loop_vars是一个(可能是嵌套的)元组,名称元组或张量传递列表,传递给两者condbody

你不能将常规python数组作为张量传递.你能做的是:

i = tf.constant(0)
l = tf.Variable([])

def body(i, l):                                               
    temp = tf.gather(array,i)
    l = tf.concat([l, [temp]], 0)
    return i+1, l

index, list_vals = tf.while_loop(cond, body, [i, l],
                                 shape_invariants=[i.get_shape(),
                                                   tf.TensorShape([None])])
Run Code Online (Sandbox Code Playgroud)

形状不变量在那里,因为通常tf.while_loop期望内部的张量的形状,而循环不会改变.

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(list_vals)
Out: array([-0.38367489, -1.76104736,  0.26266089, -2.74720812,  1.48196387,
            -0.23357525, -1.07429159, -1.79547787, -0.74316853,  0.15982138], 
           dtype=float32)
Run Code Online (Sandbox Code Playgroud)


sud*_*oer 5

TF提供了一个TensorArray来处理这种情况。从文档中

类包装动态大小的,按时间划分的一次写入Tensor数组。

此类应与动态迭代原语(例如while_loop和)一起使用map_fn。它通过特殊的“流”控制流相关性支持梯度反向传播。

这是一个例子

import tensorflow as tf

array = tf.Variable(tf.random_normal([10]))
step = tf.constant(0)
output = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

def cond(step, output):
    return step < 10

def body(step, output):
    output = output.write(step, tf.gather(array, step))
    return step + 1, output

_, final_output = tf.while_loop(cond, body, loop_vars=[step, output])

final_output = final_output.stack()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(final_output))
Run Code Online (Sandbox Code Playgroud)