小编Sau*_*wal的帖子

在 tf.where() 中使用 == 条件的问题

我最近开始使用 tensorflow 并尝试使用 tf.where() 函数。我注意到每当我使用“==”条件时它都会抛出错误。例如,当我尝试以下操作时:

t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])

t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)

t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
    print(sess.run([t3]))
Run Code Online (Sandbox Code Playgroud)

它抛出以下错误:

InvalidArgumentError:WhereOp:未处理的输入维度:0

谁能解释一下这里可能有什么错误?提前致谢!

tensorflow

3
推荐指数
1
解决办法
1774
查看次数

标签 统计

tensorflow ×1