我最近开始使用 tensorflow 并尝试使用 tf.where() 函数。我注意到每当我使用“==”条件时它都会抛出错误。例如,当我尝试以下操作时:
t = tf.constant([[1, 2, 3],
[4, 5, 6]])
t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)
t3_shape= tf.shape(t)[0]
with tf.Session() as sess:
print(sess.run([t3]))
Run Code Online (Sandbox Code Playgroud)
它抛出以下错误:
InvalidArgumentError:WhereOp:未处理的输入维度:0
谁能解释一下这里可能有什么错误?提前致谢!
您需要tf.equal进行逐元素比较:
t2 = tf.where(tf.equal(t, 2))
Run Code Online (Sandbox Code Playgroud)
t = tf.constant([[1, 2, 3],
[4, 5, 6]])
t2 = tf.where(tf.equal(t, 2))
t3 = tf.gather_nd(t,t2)
t3_shape= tf.shape(t)[0]
with tf.Session() as sess:
print(sess.run([t3]))
# [array([2])]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1774 次 |
| 最近记录: |