Pytorch:在特定张量维度中设置索引(类似于torch.index_select)

Ero*_*mic 5 python pytorch

我正在尝试获取和设置特定张量维度中的索引,如果可能的话,无需重新整形。我已经能够找到torch.index_select在获取值时执行我想要的操作的函数,但我还没有找到类似的设置函数。有吗?

对于上下文,我有一个张量和一组索引

class_energy = torch.rand(3, 10, 32, 32)
class_logits = torch.empty_like(class_energy)
idxs = [2, 3, 5, 7]
Run Code Online (Sandbox Code Playgroud)

我想访问特定维度中这些索引处的项目,以便我可以执行 log_softmax。

如果我知道暗淡的先验,那么我可以简单地使用花哨的__getitem__/__setitem__语法:例如, if dim=1, then class_energy[:, idxs]。同样,如果dim=2-> class_energy[:, :, idxs]dim=0->class_energy[idxs]等...

在 的情况下dim=1,我本质上想要这样:

class_logits[:, idxs] = F.log_softmax(class_energy[:, idxs], dim=1)
Run Code Online (Sandbox Code Playgroud)

dim不幸的是,我不知道提前的价值。当然,我可以通过以下方式提前建立精美的索引:

fancy_index = tuple([slice(None)] * dim + [idxs])
class_logits[fancy_index] = F.log_softmax(class_energy[fancy_index], dim=dim)
Run Code Online (Sandbox Code Playgroud)

但是,我想知道是否有更好的方法来做到这一点。对于 的情况__getitem__,我知道确实存在。下面的代码使用torch.index_select是等效的

fancy_index = tuple([slice(None)] * dim + [idxs])
index = torch.LongTensor(idxs).to(class_energy.device)
class_logits[fancy_index] = F.log_softmax(torch.index_select(class_energy, dim=dim, index=index, dim=dim))
Run Code Online (Sandbox Code Playgroud)

它不仅在功能上相同,而且 index_select 比使用花哨的 getitem 语法要快得多(我已经看到了 2 倍的改进)。

我的问题是,我似乎无法为__setitem__代码部分提供类似的功能。如果我能一起摆脱花哨的索引,那就太好了。我研究过Tensor.put_torch.index_puttorch.select,但这些似乎都不具备我想要的功能。有什么我遗漏的吗,或者花哨的索引是目前解决这个问题的唯一方法吗?