PyTorch 中的运行损失是什么以及它是如何计算的

Jit*_*ddi 9 python deep-learning torch pytorch torchvision

我在 PyTorch 文档中查看了教程以了解迁移学习。有一句我没看懂。

使用 计算损失后,使用 计算loss = criterion(outputs, labels)运行损失running_loss += loss.item() * inputs.size(0),最后使用 计算epoch 损失running_loss / dataset_sizes[phase]

loss.item()应该用于整个小批量(如果我错了,请纠正我)。即,如果batch_size是 4,loss.item()将给出整个 4 张图像集的损失。如果这是真的,为什么在计算时loss.item()乘以?在这种情况下,这一步不是像一个额外的乘法吗?inputs.size(0)running_loss

任何帮助,将不胜感激。谢谢!

kHa*_*hit 13

这是因为由CrossEntropy或其他损失函数给出的损失除以元素的数量,即mean默认情况下减少参数。

torch.nn.CrossEntropyLoss(权重=无,size_average=None,ignore_index=-100,reduce=None,reduction='mean')

因此,loss.item()包含整个小批量的损失,但除以批量大小。这就是为什么在计算 时loss.item()乘以批次大小,由 给出。inputs.size(0)running_loss


Piy*_*ngh 6

如果 batch_size 为 4,则 loss.item() 将给出整个 4 张图像集的损失

这取决于如何loss计算。请记住,loss是一个张量,就像其他张量一样。通常,PyTorch API默认返回平均损失

“损失是每个小批量观察结果的平均值。”

t.item()对于张量,t只需将其转换为 python 的默认 float32。

更重要的是,如果您是 PyTorch 的新手,了解我们用于t.item()维持运行损失而不是t因为 PyTorch 张量存储其值的历史记录可能会很快使您的 GPU 过载,这可能对您有所帮助。

  • `t.item()` 检索驻留在 CPU 上的损失值(作为浮点数),但要点是 `t` 包含上述损失值和梯度,后者仅与反向传播相关,因此应该计算/存储运​​行损失时不保留。因此,无论您使用的是 CPU 还是 GPU,这都是相关的。 (2认同)