我正在寻找一种TensorFlow方法来实现类似于Python的list.index()函数.
给定矩阵和要查找的值,我想知道矩阵每行中第一次出现的值.
例如,
m is a <batch_size, 100> matrix of integers
val = 23
result = [0] * batch_size
for i, row_elems in enumerate(m):
result[i] = row_elems.index(val)
Run Code Online (Sandbox Code Playgroud)
我不能假设'val'每行只出现一次,否则我会用tf.argmax(m == val)实现它.在我的例子中,重要的是获得第一次出现'val' 的索引,而不是任何索引.
Jen*_*nny 11
看起来tf.argmax
像np.argmax
(根据测试),当有多次出现最大值时将返回第一个索引.你可以用它tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1)
来得到你想要的东西.但是,tf.argmax
在多次出现最大值的情况下,当前行为未定义.
如果您担心未定义的行为,您可以应用@Igor Tsvetkov建议tf.argmin
的返回值tf.where
.例如,
# test with tensorflow r1.0
import tensorflow as tf
val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0 , 0, val, 0, val],
[val, 0, val, val, 0],
[0 , val, 0, 0, 0]]
tmp_indices = tf.where(tf.equal(m, val))
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])
with tf.Session() as sess:
print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]
Run Code Online (Sandbox Code Playgroud)
请注意,当某行包含no时tf.segment_min
会引发.在您的代码中,如果不包含,也会引发异常.InvalidArgumentError
val
row_elems.index(val)
row_elems
val
归档时间: |
|
查看次数: |
7569 次 |
最近记录: |