nn.CrossEntropyLoss()的Pytorch输入

Ama*_*bek 6 logistic-regression pytorch

我试图在PyTorch中对简单的0,1标记的数据集执行Logistic回归。标准或损失定义为:criterion = nn.CrossEntropyLoss()。该模型是:model = LogisticRegression(1,2)

我有一个成对的数据点:dat = (-3.5, 0),第一个元素是数据点,第二个元素是相应的标签。
然后,将输入的第一个元素转换为张量:tensor_input = torch.Tensor([dat[0]])
然后我将该模型应用到tensor_input: outputs = model(tensor_input)
然后,将标签转换为张量:tensor_label = torch.Tensor([dat[1]])
现在,当我尝试执行此操作时,事情就中断了:loss = criterion(outputs, tensor_label)。它给出和错误:RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes) 

    def forward(self, x):
        out = self.linear(x)
        return out

model = LogisticRegression(1,2)
criterion = nn.CrossEntropyLoss()
dat = (-3.5,0)
tensor_input = torch.Tensor([dat[0]])
outputs = binary_model(tensor_input)
tensor_label = torch.Tensor([dat[1]])
loss = criterion(outputs, tensor_label)
Run Code Online (Sandbox Code Playgroud)

我无法为自己的生活弄清楚。

den*_*ger 8

在大多数情况下,PyTorch 文档在解释不同功能方面做得非常出色;它们通常包括预期的输入维度,以及一些简单的例子。
您可以在nn.CrossEntropyLoss() 此处找到说明。

为了完成您的具体示例,让我们首先查看预期的输入维度:

输入:(N,C) 其中 C = 类数。[...]

除此之外,N通常是指批量大小(样本数)。将此与您目前拥有的进行比较:

outputs.shape
>>> torch.Size([2])
Run Code Online (Sandbox Code Playgroud)

即目前我们只有(2,),而不是(1,2)PyTorch 预期的输入维度。我们可以通过向我们当前的张量添加一个“假”维度来缓解这个问题,只需.unsqueeze()像这样使用:

outputs = binary_model(tensor_input).unsqueeze(dim=0)
outputs.shape
>>> torch.Size([1,2])
Run Code Online (Sandbox Code Playgroud)

现在我们知道了,让我们看看目标的预期输入:

目标:(N) [...]

所以我们已经得到了正确的形状。但是,如果我们尝试这样做,我们仍然会遇到错误:

RuntimeError: Expected object of scalar type Long but got scalar type Float 
              for argument #2 'target'.
Run Code Online (Sandbox Code Playgroud)

同样,错误消息相当有表现力。这里的问题是 PyTorch 张量(默认情况下)被解释为torch.FloatTensors,但输入应该是整数(或Long)。我们可以通过在张量创建期间指定确切类型来简单地做到这一点:

tensor_label = torch.LongTensor([dat[1]])
Run Code Online (Sandbox Code Playgroud)

我在 Linux 下使用 PyTorch 1.0 仅供参考。