Luc*_*ter 6 python indexing pytorch tensor
我有一个形状为 n x m 的 2D pytorch 张量。我想使用索引列表(可以使用 torch.gather 完成)对第二个维度进行索引,然后还将新值设置为索引的结果。
例子:
data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
# [3, 4, 5],
# [6, 7, 8]])
Run Code Online (Sandbox Code Playgroud)
我想选择每行指定的索引(但随后[1,5,7]
也将这些值设置为另一个数字 - 例如 42
我可以通过执行以下操作来逐行选择所需的列:
data.gather(1, indices)
tensor([[1],
[5],
[7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather
# does not use the same storage as the original tensor
Run Code Online (Sandbox Code Playgroud)
这很好,但我现在想更改这些值,并且更改也会影响data
张量。
我可以用它来做我想要实现的事情,但它似乎非常不Pythonic:
data.gather(1, indices)
tensor([[1],
[5],
[7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather
# does not use the same storage as the original tensor
Run Code Online (Sandbox Code Playgroud)
关于如何更优雅地做到这一点有什么提示吗?
您正在寻找的是torch.scatter_
带有value
选项的东西。
\n\n\n
Tensor.scatter_(dim, index, src, reduce=None) \xe2\x86\x92 Tensor
\n将张量中的所有值src
写入张量self
中指定的索引处index
。对于 中的每个值src
,其输出index
由 src 中的索引dimension != dim
和 的相应值索引指定dimension = dim
。使用 2D 张量作为输入 和
\ndim=1
,运算为:\nself[i][index[i][j]] = src[i][j]
但没有提到 value 参数......
\n使用value=42
、 和dim=1
,这将对数据产生以下影响:
data[i][index[i][j]] = 42\n
Run Code Online (Sandbox Code Playgroud)\n这里就地应用:
\n>>> data.scatter_(index=indices, dim=1, value=42)\n>>> data\ntensor([[ 0, 42, 2],\n [ 3, 4, 42],\n [ 6, 42, 8]])\n
Run Code Online (Sandbox Code Playgroud)\n
归档时间: |
|
查看次数: |
2256 次 |
最近记录: |