我正在尝试获取和设置特定张量维度中的索引,如果可能的话,无需重新整形。我已经能够找到torch.index_select
在获取值时执行我想要的操作的函数,但我还没有找到类似的设置函数。有吗?
对于上下文,我有一个张量和一组索引
class_energy = torch.rand(3, 10, 32, 32)
class_logits = torch.empty_like(class_energy)
idxs = [2, 3, 5, 7]
Run Code Online (Sandbox Code Playgroud)
我想访问特定维度中这些索引处的项目,以便我可以执行 log_softmax。
如果我知道暗淡的先验,那么我可以简单地使用花哨的__getitem__
/__setitem__
语法:例如, if dim=1
, then class_energy[:, idxs]
。同样,如果dim=2
-> class_energy[:, :, idxs]
、dim=0
->class_energy[idxs]
等...
在 的情况下dim=1
,我本质上想要这样:
class_logits[:, idxs] = F.log_softmax(class_energy[:, idxs], dim=1)
Run Code Online (Sandbox Code Playgroud)
dim
不幸的是,我不知道提前的价值。当然,我可以通过以下方式提前建立精美的索引:
fancy_index = tuple([slice(None)] * dim + [idxs])
class_logits[fancy_index] = F.log_softmax(class_energy[fancy_index], dim=dim)
Run Code Online (Sandbox Code Playgroud)
但是,我想知道是否有更好的方法来做到这一点。对于 的情况__getitem__
,我知道确实存在。下面的代码使用torch.index_select
是等效的
fancy_index = tuple([slice(None)] * dim + [idxs])
index = torch.LongTensor(idxs).to(class_energy.device)
class_logits[fancy_index] = F.log_softmax(torch.index_select(class_energy, dim=dim, index=index, dim=dim))
Run Code Online (Sandbox Code Playgroud)
它不仅在功能上相同,而且 index_select 比使用花哨的 getitem 语法要快得多(我已经看到了 2 倍的改进)。
我的问题是,我似乎无法为__setitem__
代码部分提供类似的功能。如果我能一起摆脱花哨的索引,那就太好了。我研究过Tensor.put_
、torch.index_put
和torch.select
,但这些似乎都不具备我想要的功能。有什么我遗漏的吗,或者花哨的索引是目前解决这个问题的唯一方法吗?
归档时间: |
|
查看次数: |
2464 次 |
最近记录: |