PyTorch [1 if x > 0.5 else 0 for x in output ] 带张量

Bri*_*nto 3 python pytorch torchvision

我有一个 sigmoid 函数的列表输出作为 PyTorch 中的张量

例如

output (type) = torch.Size([4]) tensor([0.4481, 0.4014, 0.5820, 0.2877], device='cuda:0',
Run Code Online (Sandbox Code Playgroud)

在进行二元分类时,我想将所有值从 0.5 变为 0,将高于 0.5 的值变为 1。

传统上,您可以使用 NumPy 数组使用列表迭代器:

output_prediction = [1 if x > 0.5 else 0 for x in outputs ]
Run Code Online (Sandbox Code Playgroud)

这会起作用,但是我稍后必须将 output_prediction 转换回张量才能使用

torch.sum(ouput_prediction == labels.data)
Run Code Online (Sandbox Code Playgroud)

其中labels.data 是标签的二进制张量。

有没有办法将列表迭代器与张量一起使用?

zih*_*hao 17

prob = torch.tensor([0.3,0.4,0.6,0.7])

out = (prob>0.5).float()
# tensor([0.,0.,1.,1.])
Run Code Online (Sandbox Code Playgroud)

说明:在pytorch中,可以直接使用prob>0.5获取torch.bool类型张量。然后你可以通过.float().