pytorch闪电模型的输出预测

Tom*_*m S 10 python pytorch pytorch-lightning

这可能是一个非常简单的问题。我刚刚开始使用 PyTorch Lightning,不知道如何在训练后接收模型的输出。

我对 y_train 和 y_test 的预测感兴趣,作为某种数组(后续步骤中的 PyTorch 张量或 NumPy 数组),以使用不同的脚本在标签旁边绘制。

dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset, **train_params)
val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0,max_epochs=max_epochs, logger=logger, progress_bar_refresh_rate=20, callbacks=[early_stop_callback], num_sanity_val_steps=0)
trainer.fit(mynet)
Run Code Online (Sandbox Code Playgroud)

在我的闪电模块中,我具有以下功能:

def __init__(self, random_inputs):

def forward(self, x):

def train_dataloader(self):
    
def val_dataloader(self):

def training_step(self, batch, batch_nb):

def training_epoch_end(self, outputs):

def validation_step(self, batch, batch_nb):

def validation_epoch_end(self, outputs):

def configure_optimizers(self):
Run Code Online (Sandbox Code Playgroud)

我是否需要特定的预测函数,或者是否有任何我看不到的已经实现的方法?

小智 12

我不同意这些答案:OP的问题似乎集中在他应该如何使用闪电训练的模型来获得一般预测,而不是训练管道中的特定步骤。在这种情况下,用户不需要走到训练器对象附近的任何地方 - 这些并不打算用于一般预测,因此上面的答案鼓励反模式(每次我们都随身携带训练器对象)想要对未来阅读这些答案的任何人做一些预测。

trainer我们可以直接从已定义的闪电模块中获得预测,而不是使用:如果我有闪电模块的(经过训练的)实例model = Net(...),那么使用该模型来获取输入的预测x只需通过调用即可实现model(x)(只要该forward方法已在闪电模块上实现/覆盖 - 这是必需的)。

相反,Trainer.predict()这并不是使用经过训练的模型获得预测的预期方法。Trainer API 提供了和LightningModule 的方法tune,作为训练管道的一部分,在我看来,该方法是为单独的数据加载器上的临时预测提供的,作为不太“标准”训练步骤的一部分。fittestpredict

OP 的问题(我需要一个特定的预测函数还是有任何我没有看到的已经实现的方法?)暗示他们不熟悉该方法forward()在 PyTorch 中的工作方式,但询问是否已经有一种方法他们看不到的预测。因此,完整的答案需要进一步解释该forward()方法在预测过程中的适用位置:

之所以model(x)有效,是因为 Lightning 模块是 的子类torch.nn.Module,并且它们实现了一个称为的神奇方法__call__(),这意味着我们可以像调用函数一样调用类实例。__call__()依次调用forward(),这就是为什么我们需要在 Lightning 模块中重写该方法。

注意。因为forward是 我们使用 时调用的逻辑的一部分model(x),所以总是建议使用model(x)来代替model.forward(x)预测,除非您有特定的原因需要偏离。

  • 很高兴您指出了如何直接运行网络,因为从使用 Pytorch Lightning 开始而无需使用 Pytorch 直接隐藏了底层机制。我认为,在某些情况下使用 Trainer 类(甚至用于预测)仍然是合理的,因为它负责将模型和数据放到 GPU 上,它可以调用某些钩子,为什么要重新发明轮子呢?这不是反模式,将类重命名为“Commander”,您的大部分论点都是无效的。我仍然认为你指出这一点很好,但反模式太强大了。 (4认同)

sus*_*mit 5

predict您也可以使用该方法。这是文档中的示例。https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html

class LitMNISTDreamer(LightningModule):

    def forward(self, z):
        imgs = self.decoder(z)
        return imgs

    def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None):
        return self(batch)


model = LitMNISTDreamer()
trainer.predict(model, datamodule) 
Run Code Online (Sandbox Code Playgroud)

  • 似乎同时添加了预测方法。我只是感到困惑,以前不可用。 (2认同)