如何获得 PyTorch Lightning 的总体测试精度?

Ger*_*nuk 6 pytorch-lightning

如何trainer.test使用该方法获得所有批次的总准确度?

我知道我可以实施,model.test_step但这只是针对单个批次。我需要整个数据集的准确性。我可以用来torchmetrics.Accuracy积累准确性。但是将其结合起来并获得总准确度的正确方法是什么?model.test_step由于批量测试分数不是很有用,无论如何应该返回什么?我可以以某种方式破解它,但令我惊讶的是,我在互联网上找不到任何示例来演示如何使用 pytorch-lightning 原生方式获得准确性。

Mik*_*e B 3

您可以在这里看到(https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#automatic-logging),参数on_epochlog纪元结束时自动累积并记录。正确的做法是:

from torchmetrics import Accuracy

def validation_step(self, batch, batch_idx): 
    x, y = batch 
    preds = self.forward(x) 
    loss = self.criterion(preds, y) 
    accuracy = Accuracy()
    acc = accuracy(preds, y)
    self.log('accuracy', acc, on_epoch=True)
    return loss 
Run Code Online (Sandbox Code Playgroud)

如果您想要自定义缩减函数,可以使用reduce_fx参数进行设置,默认值为torch.mean()log()可以从你的任何方法调用LightningModule