小编God*_* Ho的帖子

在 PyTorch 中加载用于推理的迁移学习模型的正确方法是什么?

我正在使用基于 Resnet152 的迁移学习来训练模型。基于 PyTorch 教程,我在保存经过训练的模型并加载它进行推理方面没有问题。但是,加载模型所需的时间很慢。我不知道我做对了没有,这是我的代码:

将训练好的模型保存为状态字典:

torch.save(model.state_dict(), 'model.pkl')
Run Code Online (Sandbox Code Playgroud)

加载它以进行推理:

model = models.resnet152()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(classes))
st = torch.load('model.pkl', map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(st)
model.eval()
Run Code Online (Sandbox Code Playgroud)

我对代码进行了计时,发现第一行model = models.resnet152()加载时间最长。在 CPU 上,测试一张图像需要 10 秒。所以我的想法是这可能不是加载它的正确方法?

如果我像这样保存整个模型而不是 state.dict:

torch.save(model, 'model_entire.pkl')
Run Code Online (Sandbox Code Playgroud)

并像这样测试它:

model = torch.load('model_entire.pkl')
model.eval()
Run Code Online (Sandbox Code Playgroud)

在同一台机器上,测试一张图像只需 5 秒。

所以我的问题是:这是加载 state_dict 进行推理的正确方法吗?

python python-3.x pytorch

6
推荐指数
1
解决办法
1438
查看次数

标签 统计

python ×1

python-3.x ×1

pytorch ×1