如何在 Pytorch 中将 one-hot 向量转换为标签索引并返回?

Gul*_*zar 12 python one-hot-encoding multiclass-classification pytorch

如何将标签向量转换为单热编码并返回到 Pytorch 中?

在经历整个论坛讨论后,问题的解决方案被复制到这里,而不是仅仅通过谷歌搜索找到一个简单的解决方案。

Gul*_*zar 19

来自Pytorch 论坛

import torch
import numpy as np


labels = torch.randint(0, 10, (10,))

# labels --> one-hot 
one_hot = torch.nn.functional.one_hot(labels)
# one-hot --> labels
labels_again = torch.argmax(one_hot, dim=1)

np.testing.assert_equals(labels.numpy(), labels_again.numpy())
Run Code Online (Sandbox Code Playgroud)

  • 对于必须指定类数量的情况,请注意[此答案](/sf/answers/5234890531/)[大多数情况下] (2认同)

小智 9

由于我无法评论已接受的答案,我只是想补充一点,如果您的目标不包括所有类别(例如,因为您分批训练),您可以指定类别数量作为参数:

# labels --> one-hot 
one_hot = torch.nn.functional.one_hot(target, num_classes=7)
Run Code Online (Sandbox Code Playgroud)