将张量转换为索引的一个热编码张量

Rya*_*an 2 one-hot-encoding pytorch

我有形状为(1,1,128,128,128)的标签张量,其中值的范围可能为0.24。我想使用nn.fucntional.one_hot函数将其转换为一个热编码张量

n = 24
one_hot = torch.nn.functional.one_hot(indices, n)
Run Code Online (Sandbox Code Playgroud)

但这确实需要一个指数张量,老实说,我不确定如何获得这些指数。我唯一的张量是上述形状的标签张量,它包含1-24范围内的值,而不是索引

如何从张量中获取索引张量?提前致谢。

Ber*_*iel 5

如果您得到的错误是此错误:

Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
RuntimeError: one_hot is only applicable to index tensor.
Run Code Online (Sandbox Code Playgroud)

也许您只需要转换为int64

import torch

# random Tensor with the shape you said
indices = torch.Tensor(1, 1, 128, 128, 128).random_(1, 24)
# indices.shape => torch.Size([1, 1, 128, 128, 128])
# indices.dtype => torch.float32

n = 24
one_hot = torch.nn.functional.one_hot(indices.to(torch.int64), n)
# one_hot.shape => torch.Size([1, 1, 128, 128, 128, 24])
# one_hot.dtype => torch.int64
Run Code Online (Sandbox Code Playgroud)

您也可以使用indices.long()

  • 请注意,您需要转换为的整数的精度确实很重要。如果您转换为 int8 或 int16 或 int32 等(可能认为 64 位对于类索引来说太多),您仍然会收到错误。 (3认同)