如何加载和使用 PyTorch (.pth.tar) 模型

11 python neural-network python-3.x torch pytorch

我对Torch不是很熟悉,主要使用Tensorflow。然而,我需要使用在 Torch 中重新训练的重新训练的初始模型。由于为我的特定应用程序重新训练初始模型需要大量计算资源,因此我想使用已经重新训练的模型。

该模型保存为.pth.tar文件。

我希望能够首先加载这个模型。到目前为止,我已经知道我必须使用以下内容:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')
Run Code Online (Sandbox Code Playgroud)

这似乎有效,因为print(model)打印出一大组数字和其他值,我认为它们是权重和偏差的值。

之后,我需要能够用它对图像进行分类。我一直无法弄清楚这一点。我必须如何格式化图像?是否应该将图像转换为数组?之后,我必须如何将输入数据传递到网络?

cle*_*ros 13

您基本上需要执行与张量流中相同的操作。也就是说,当您存储网络时,只会存储参数(即网络中的可训练对象),而不存储“粘合剂”,这就是使用经过训练的模型所需的所有逻辑。因此,如果您有.pth.tar文件,则可以加载它,从而覆盖已定义的模型的参数值。

这意味着保存/加载模型的一般过程如下:

  • 写下你的网络定义(即你的nn.Module对象)
  • 以您想要的方式训练或以其他方式更改网络参数
  • 使用保存参数torch.save
  • 当您想使用该网络时,请使用相同的nn.Module对象定义首先实例化 pytorch 网络
  • 然后使用覆盖网络参数的值torch.load

以下是有关如何执行此操作的一些参考文献的讨论:pytorch 论坛

这是一个超短的 mwe:

# to store
torch.save({
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')

# to load
checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
Run Code Online (Sandbox Code Playgroud)

  • OP从未批准你的答案,而且你无论如何都遵守了我的问题......所以这不是劫持。将其视为澄清请求。感谢您的帮助,它澄清了事情。 (3认同)