TensorFlow获取特定列的每一行元素

Kas*_*yap 10 python numpy tensorflow

如果A是这样的TensorFlow变量

A = tf.Variable([[1, 2], [3, 4]])
Run Code Online (Sandbox Code Playgroud)

并且index是另一个变量

index = tf.Variable([0, 1])
Run Code Online (Sandbox Code Playgroud)

我想使用此索引来选择每行中的列.在这种情况下,第一行的第0项和第二行的第1项.

如果A是Numpy数组,那么为了得到索引中提到的相应行的列,我们可以做到

x = A[np.arange(A.shape[0]), index]
Run Code Online (Sandbox Code Playgroud)

结果就是

[1, 4]
Run Code Online (Sandbox Code Playgroud)

什么是TensorFlow等效操作/操作?我知道TensorFlow不支持很多索引操作.如果无法直接完成,将会有什么工作?

小智 7

您可以使用一种热方法来创建 one_hot 数组并将其用作布尔掩码来选择您想要的索引。

A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([0, 1])

one_hot_mask = tf.one_hot(index, A.shape[1], on_value = True, off_value = False, dtype = tf.bool)
output = tf.boolean_mask(A, one_hot_mask)
Run Code Online (Sandbox Code Playgroud)


Max*_*sov 6

您可以使用行索引扩展列索引,然后使用collect_nd:

import tensorflow as tf

A = tf.constant([[1, 2], [3, 4]])
indices = tf.constant([1, 0])

# prepare row indices
row_indices = tf.range(tf.shape(indices)[0])

# zip row indices with column indices
full_indices = tf.stack([row_indices, indices], axis=1)

# retrieve values by indices
S = tf.gather_nd(A, full_indices)

session = tf.InteractiveSession()
session.run(S)
Run Code Online (Sandbox Code Playgroud)


Kas*_*yap 3

经过一段时间的摸索。我发现两个可能有用的功能。

其一是,tf.gather_nd()如果您可以生成以下形式的张量[[0, 0], [1, 1]],则这可能会很有用,从而您可以执行以下操作:

index = tf.constant([[0, 0], [1, 1]])

tf.gather_nd(A, index)

如果由于某种原因您无法生成以下形式的向量[[0, 0], [1, 1]](我无法生成此向量,因为我的情况下的行数取决于占位符),那么我发现的解决方法是使用tf.py_func(). 这是有关如何完成此操作的示例代码

import tensorflow as tf 
import numpy as np 

def index_along_every_row(array, index):
    N, _ = array.shape 
    return array[np.arange(N), index]

a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32)
index = tf.Variable([0, 1], dtype=tf.int32)
a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0]
session = tf.InteractiveSession()

a.initializer.run()
index.initializer.run()
a_slice = a_slice_op.eval() 
Run Code Online (Sandbox Code Playgroud)

a_slice将是一个 numpy 数组[1, 4]