加载预训练模型pytorch - dict对象没有属性eval

nir*_*air 9 python deep-learning conv-neural-network pytorch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best)
Run Code Online (Sandbox Code Playgroud)

我正在这样保存我的模型。如何加载模型,以便我可以在其他地方使用它,例如 cnn 可视化?

这就是我现在加载模型的方式:

torch.load('model_best.pth.tar')
Run Code Online (Sandbox Code Playgroud)

但是当我这样做时,我收到此错误:

AttributeError:“dict”对象没有属性“eval”

我在这里缺少什么???

编辑:我想使用我训练的模型来可视化过滤器和梯度。我正在使用这个repo来制作 vis。我将第 179 行替换为torch.load('model_best.pth.tar')

Sal*_*goz 5

首先,您已经说明了您的模型。torch.load() 为您提供了一本字典。该字典没有 eval 函数。所以你应该将权重上传到你的模型中。

import torch
from modelfolder import yourmodel

model = yourmodel()
checkpoint = torch.load('model_best.pth.tar')
try:
    checkpoint.eval()
except AttributeError as error:
    print error
### 'dict' object has no attribute 'eval'

model.load_state_dict(checkpoint['state_dict'])
### now you can evaluate it
model.eval()
Run Code Online (Sandbox Code Playgroud)