Yin*_*Gao 3 python tensorflow pytorch
我使用 MNIST 数据集训练了一个 CNN 模型,现在想要预测图像的分类,其中包含数字 3。
但是当我尝试使用这个 CNN 进行预测时,pytorch 给了我这个错误:
TypeError: 'collections.OrderedDict' object is not callable
Run Code Online (Sandbox Code Playgroud)
这就是我写的:
cnn = torch.load("/usr/prakt/w153/Desktop/score_detector.pkl")
img = scipy.ndimage.imread("/usr/prakt/w153/Desktop/resize_num_three.png")
test_x = Variable(torch.unsqueeze(torch.FloatTensor(img), dim=1), volatile=True).type(torch.FloatTensor).cuda()
test_output, last_layer = cnn(test_x)
pred = torch.max(test_output, 1)[1].cuda().data.squeeze()
print(pred)
Run Code Online (Sandbox Code Playgroud)
这里有一些解释:
img要预测的图像大小为 28*28score_detector.pkl是经过训练的 CNN 模型
任何帮助将不胜感激!
实际上,您正在加载 state_dict 而不是模型本身。
保存模型如下:
torch.save(model.state_dict(), 'model_state.pth')
Run Code Online (Sandbox Code Playgroud)
而要加载模型状态,您首先需要初始化模型,然后加载状态
model = Model()
model.load_state_dict(torch.load('model_state.pth'))
Run Code Online (Sandbox Code Playgroud)
如果您在 GPU 上训练模型,但想在没有 CUDA 的笔记本电脑上加载模型,那么您需要再添加一个参数
model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
Run Code Online (Sandbox Code Playgroud)
我很确定score_detector.pkl实际上是一个 state_dict 而不是模型本身。您需要首先实例化模型,然后加载 state_dict,因此您的第一行应替换为如下内容:
cnn = MyModel()
cnn.load_state_dict("/usr/prakt/w153/Desktop/score_detector.pkl")
Run Code Online (Sandbox Code Playgroud)
然后剩下的就应该可以了。请参阅此链接了解更多信息。
| 归档时间: |
|
| 查看次数: |
9742 次 |
| 最近记录: |