如何在PyTorch Lightning中按每个纪元从记录器中提取损失和准确性?

Wak*_*ame 9 logging tensorboard pytorch pytorch-lightning

我想提取所有数据来绘制绘图,而不是使用张量板。我的理解是,自从张量板绘制线图以来,所有带有损失和准确性的日志都存储在定义的目录中。

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

但是,我想知道如何从 pytorch Lightning 中的记录器中提取所有日志。接下来是训练部分的代码示例。

#model
ssl_classifier = SSLImageClassifier(lr=lr)

#train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')

trainer = pl.Trainer(progress_bar_refresh_rate=20,
                            gpus=1,
                            max_epochs = max_epoch,
                            logger = logger,
                            )

trainer.fit(ssl_classifier, train_loader, val_loader)
Run Code Online (Sandbox Code Playgroud)

我已经确认trainer.logger.log_dir返回的目录似乎保存日志并trainer.logger.log_metrics返回<bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>

trainer.logged_metrics仅返回最后一个纪元的日志,例如

{'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185}
Run Code Online (Sandbox Code Playgroud)

你知道如何解决这个情况吗?

Aya*_*Das 3

闪电网络本身并不存储所有日志。它所做的就是将它们流式传输logger到实例中,然后记录器决定要做什么。

检索所有记录的指标的最佳方法是使用自定义回调:

class MetricTracker(Callback):

  def __init__(self):
    self.collection = []

  def on_validation_batch_end(trainer, module, outputs, ...):
    vacc = outputs['val_acc'] # you can access them here
    self.collection.append(vacc) # track them

  def on_validation_epoch_end(trainer, module):
    elogs = trainer.logged_metrics # access it here
    self.collection.append(elogs)
    # do whatever is needed
Run Code Online (Sandbox Code Playgroud)

然后您可以从回调实例访问所有记录的内容

cb = MetricTracker()
Trainer(callbacks=[cb])

cb.collection # do you plotting and stuff
Run Code Online (Sandbox Code Playgroud)