小编Eli*_*ias的帖子

pytorch nn.CrossEntropyLoss()中的交叉熵损失

也许有人能够在这里帮助我.我正在尝试计算网络给定输出的交叉熵损失

print output
Variable containing:
1.00000e-02 *
-2.2739  2.9964 -7.8353  7.4667  4.6921  0.1391  0.6118  5.2227  6.2540     
-7.3584
[torch.FloatTensor of size 1x10]
Run Code Online (Sandbox Code Playgroud)

和所需的标签,形式

print lab
Variable containing:
x
[torch.FloatTensor of size 1]
Run Code Online (Sandbox Code Playgroud)

其中x是0到9之间的整数.根据pytorch文档(http://pytorch.org/docs/master/nn.html)

criterion = nn.CrossEntropyLoss()
loss = criterion(output, lab)
Run Code Online (Sandbox Code Playgroud)

这应该工作,但不幸的是我得到一个奇怪的错误

TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, !torch.FloatTensor!, torch.FloatTensor, bool, NoneType, torch.FloatTensor, int), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight, int ignore_index) …
Run Code Online (Sandbox Code Playgroud)

entropy loss torch cross-entropy pytorch

6
推荐指数
1
解决办法
9820
查看次数

标签 统计

cross-entropy ×1

entropy ×1

loss ×1

pytorch ×1

torch ×1