pytorch 中的一种热门编码

o0o*_*o0o 2 pytorch

我对编码真的很陌生,现在我正在尝试将我的标签变成一种热门编码。我已经完成将 np.array 传输到张量,如下所示

tensor([4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 4., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2.], dtype=torch.float64)
Run Code Online (Sandbox Code Playgroud)

and I am using code to do one hot encoding 

aaa = F.one_hot(torch_qyh, num_classes=5)
Run Code Online (Sandbox Code Playgroud)

但是,出现错误,显示“RuntimeError:one_hot 仅适用于索引张量”。任何帮助将不胜感激。

Dis*_*ani 5

您必须将其转换为long类型。不能用浮动来做到这一点。F.one_hot只需要 LongTensor。

F.one_hot(t.long())
Run Code Online (Sandbox Code Playgroud)