设置 torch.gather(...) 调用的结果

Luc*_*ter 6 python indexing pytorch tensor

我有一个形状为 n x m 的 2D pytorch 张量。我想使用索引列表(可以使用 torch.gather 完成)对第二个维度进行索引,然后还将新值设置为索引的结果。

例子:

data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
Run Code Online (Sandbox Code Playgroud)

我想选择每行指定的索引(但随后[1,5,7]也将这些值设置为另一个数字 - 例如 42

我可以通过执行以下操作来逐行选择所需的列:

data.gather(1, indices)
tensor([[1],
        [5],
        [7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather 
                                # does not use the same storage as the original tensor
Run Code Online (Sandbox Code Playgroud)

这很好,但我现在想更改这些值,并且更改也会影响data张量。

我可以用它来做我想要实现的事情,但它似乎非常不Pythonic:

data.gather(1, indices)
tensor([[1],
        [5],
        [7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather 
                                # does not use the same storage as the original tensor
Run Code Online (Sandbox Code Playgroud)

关于如何更优雅地做到这一点有什么提示吗?

Iva*_*van 4

您正在寻找的是torch.scatter_带有value选项的东西。

\n
\n

Tensor.scatter_(dim, index, src, reduce=None) \xe2\x86\x92 Tensor
\n将张量中的所有值src写入张量self中指定的索引处index。对于 中的每个值src,其输出index由 src 中的索引dimension != dim和 的相应值索引指定dimension = dim

\n

使用 2D 张量作为输入 和dim=1,运算为:\n
self[i][index[i][j]] = src[i][j]

\n
\n

但没有提到 value 参数......

\n
\n

使用value=42、 和dim=1,这将对数据产生以下影响:

\n
data[i][index[i][j]] = 42\n
Run Code Online (Sandbox Code Playgroud)\n

这里就地应用:

\n
>>> data.scatter_(index=indices, dim=1, value=42)\n>>> data\ntensor([[ 0, 42,  2],\n        [ 3,  4, 42],\n        [ 6, 42,  8]])\n
Run Code Online (Sandbox Code Playgroud)\n