如何使用 PyTorch 沿特定维度进行热编码?

jul*_*iet 3 python one-hot-encoding pytorch tensor

我有一个大小为 的张量[3, 15, 136],其中:

  • 3 is batch size
  • 15 - sequence length
  • 136 is tokens

我想使用维度 (136) 中的概率来独热我的张量tokens。为此,我想提取序列长度中每个字母的标记维度,并将其置于1最大可能性,并将所有其他标记标记为0

uke*_*emi 5

你可以使用PyTorch的one_hot函数来实现这一点:

import torch.nn.functional as F

t = torch.rand(3, 15, 136)

F.one_hot(t.argmax(dim=2), 136)
Run Code Online (Sandbox Code Playgroud)