Luc*_*ter 6 python indexing pytorch tensor
我有一个形状为 n x m 的 2D pytorch 张量。我想使用索引列表(可以使用 torch.gather 完成)对第二个维度进行索引,然后还将新值设置为索引的结果。
例子:
data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
# [3, 4, 5],
# [6, 7, 8]])
Run Code Online (Sandbox Code Playgroud)
我想选择每行指定的索引(但随后[1,5,7]也将这些值设置为另一个数字 - 例如 42
我可以通过执行以下操作来逐行选择所需的列:
data.gather(1, indices)
tensor([[1],
[5],
[7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather
# does not use the same storage as the original tensor
Run Code Online (Sandbox Code Playgroud)
这很好,但我现在想更改这些值,并且更改也会影响data张量。
我可以用它来做我想要实现的事情,但它似乎非常不Pythonic:
data.gather(1, indices)
tensor([[1],
[5],
[7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather
# does not use the same storage as the original tensor
Run Code Online (Sandbox Code Playgroud)
关于如何更优雅地做到这一点有什么提示吗?
您正在寻找的是torch.scatter_带有value选项的东西。
\n\n\n
Tensor.scatter_(dim, index, src, reduce=None) \xe2\x86\x92 Tensor
\n将张量中的所有值src写入张量self中指定的索引处index。对于 中的每个值src,其输出index由 src 中的索引dimension != dim和 的相应值索引指定dimension = dim。使用 2D 张量作为输入 和
\ndim=1,运算为:\nself[i][index[i][j]] = src[i][j]
但没有提到 value 参数......
\n使用value=42、 和dim=1,这将对数据产生以下影响:
data[i][index[i][j]] = 42\nRun Code Online (Sandbox Code Playgroud)\n这里就地应用:
\n>>> data.scatter_(index=indices, dim=1, value=42)\n>>> data\ntensor([[ 0, 42, 2],\n [ 3, 4, 42],\n [ 6, 42, 8]])\nRun Code Online (Sandbox Code Playgroud)\n
| 归档时间: |
|
| 查看次数: |
2256 次 |
| 最近记录: |