为什么 tf.Print() 不起作用?

i s*_*eal 6 python python-3.x tensorflow

我有这个代码片段:

import tensorflow as tf
import numpy as np

    # batch x time x events x dim
batch = 2
time = 3
events = 4
tensor = np.random.rand(batch, time, events)

tensor[0][0][2] = 0
tensor[0][0][3] = 0

tensor[0][1][3] = 0

tensor[0][2][1] = 0
tensor[0][2][2] = 0
tensor[0][2][3] = 0

tensor[1][0][3] = 0

def cum_sum(prev, cur):
    non_zeros = tf.equal(cur, 0.)
    tf.Print(non_zeros, [non_zeros], "message ")
    tf.Print(cur, [cur])
    return cur

elems = tf.constant([1,2,3],dtype=tf.int64)
#alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
cum_sum_ = tf.scan(cum_sum, tensor)

s = tf.Session()

s.run(cum_sum_)
Run Code Online (Sandbox Code Playgroud)

tf.Print在传递给 的函数中有两个语句tf.scan,但是当我运行累积总和时,我没有得到任何打印语句。难道我做错了什么?

Ign*_*ier 7

tf.Print 不是这样工作的。打印节点需要进入图中才能执行。我强烈建议您查看教程以了解如何使用它。

如果您有任何问题随时问。