选择/屏蔽每行中的不同列索引

Dep*_*ify 8 pytorch

在pytorch中我有一个多维张量,称之为X

X = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], ...]
Run Code Online (Sandbox Code Playgroud)

现在我想为每一行选择不同的列索引,如下所示

indices = [[0], [1], [0], [2], ...]
# now I expect following values to be returned:
[[1], [5], [7], [12], ...]
Run Code Online (Sandbox Code Playgroud)

我也想实现相反的目标,以便对于给定的指数我得到

[[2, 3], [4, 6], [8, 9], [10, 11]]
Run Code Online (Sandbox Code Playgroud)

有没有一种“简单”的方法可以在没有 for 循环的情况下实现这一目标?如果有任何想法,我将不胜感激。

Dep*_*ify 16

事实上,该torch.gather函数正是执行此操作。

例如

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0], [1], [0], [2]])
a.gather(1, indices)
Run Code Online (Sandbox Code Playgroud)

会准确返回

tensor([[ 1],
        [ 5],
        [ 7],
        [12]])
Run Code Online (Sandbox Code Playgroud)

我不再需要相反的东西,但为此我建议只创建一个包含所有值的掩码,然后将“收集”张量的相应索引设置为 0,或者只是创建一个包含相应相反键的新“收集”张量。例如:

indices_opposite = [np.setdiff1d(np.arange(a.size(1)), i) for i in indices.numpy()]
Run Code Online (Sandbox Code Playgroud)