Tensorflow 获取张量中值的索引

Mun*_*ong 4 python matrix tensorflow

给定一个矩阵和向量?我想在矩阵的相应行中找到值的索引。

m = tf.constant([[0, 2, 1],[2, 0, 1]])  # matrix
y = tf.constant([1,2])  # values whose indices should be found
Run Code Online (Sandbox Code Playgroud)

理想的输出是 [2,0],因为 y 的第一个值 1 位于 m 的第一个向量的索引 2 处。y 的第二个值 2 位于 m 的第二个向量的索引 0 处。

Mun*_*ong 6

我只找到一种解决方案。但是不知道有没有更好的。

m = tf.constant([[0, 2, 1],[2, 0, 1]])  # matrix
y = tf.constant([1,2])  # values whose indices should be found
y = tf.reshape(y, (y.shape[0], 1))  # [[1], [2]]
cols = tf.where(tf.equal(m, y))[:,-1]  # [2,0]

init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    print(sess.run(cols))
Run Code Online (Sandbox Code Playgroud)

以上输出: [2, 0]