我正在寻找一种TensorFlow方法来实现类似于Python的list.index()函数.
给定矩阵和要查找的值,我想知道矩阵每行中第一次出现的值.
例如,
m is a <batch_size, 100> matrix of integers
val = 23
result = [0] * batch_size
for i, row_elems in enumerate(m):
result[i] = row_elems.index(val)
Run Code Online (Sandbox Code Playgroud)
我不能假设'val'每行只出现一次,否则我会用tf.argmax(m == val)实现它.在我的例子中,重要的是获得第一次出现'val' 的索引,而不是任何索引.
tensorflow ×1