在 tf.where() 中使用 == 条件的问题

Sau*_*wal 3 tensorflow

我最近开始使用 tensorflow 并尝试使用 tf.where() 函数。我注意到每当我使用“==”条件时它都会抛出错误。例如,当我尝试以下操作时:

t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])

t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)

t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
    print(sess.run([t3]))
Run Code Online (Sandbox Code Playgroud)

它抛出以下错误:

InvalidArgumentError:WhereOp:未处理的输入维度:0

谁能解释一下这里可能有什么错误?提前致谢!

Psi*_*dom 6

您需要tf.equal进行逐元素比较:

t2 = tf.where(tf.equal(t, 2))
Run Code Online (Sandbox Code Playgroud)
t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])

t2 = tf.where(tf.equal(t, 2))
t3 = tf.gather_nd(t,t2)   
t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
    print(sess.run([t3]))

# [array([2])]
Run Code Online (Sandbox Code Playgroud)