如何在TensorFlow中查找第一个匹配元素的索引

Igo*_*kov 9 tensorflow

我正在寻找一种TensorFlow方法来实现类似于Python的list.index()函数.

给定矩阵和要查找的值,我想知道矩阵每行中第一次出现的值.

例如,

m is a <batch_size, 100> matrix of integers
val = 23

result = [0] * batch_size
for i, row_elems in enumerate(m):
  result[i] = row_elems.index(val)
Run Code Online (Sandbox Code Playgroud)

我不能假设'val'每行只出现一次,否则我会用tf.argmax(m == val)实现它.在我的例子中,重要的是获得第一次出现'val' 的索引,而不是任何索引.

Jen*_*nny 11

看起来tf.argmaxnp.argmax(根据测试),当有多次出现最大值时将返回第一个索引.你可以用它tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)来得到你想要的东西.但是,tf.argmax在多次出现最大值的情况下,当前行为未定义.

如果您担心未定义的行为,您可以应用@Igor Tsvetkov建议tf.argmin的返回值tf.where.例如,

# test with tensorflow r1.0
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0  ,   0, val,   0, val],
          [val,   0, val, val,   0],
          [0  , val,   0,   0,   0]]

tmp_indices = tf.where(tf.equal(m, val))
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]
Run Code Online (Sandbox Code Playgroud)

请注意,当某行包含no时tf.segment_min会引发.在您的代码中,如果不包含,也会引发异常.InvalidArgumentErrorvalrow_elems.index(val)row_elemsval