Why*_*rch 3 python machine-learning pytorch
我正在尝试使用 torch.load 加载预训练模型。
我收到以下错误:
ModuleNotFoundError: No module named 'utils'
Run Code Online (Sandbox Code Playgroud)
我通过从命令行打开它来检查我使用的路径是否正确。可能是什么原因造成的?
这是我的代码:
import torch
import sys
PATH = './gan.pth'
model = torch.load(PATH)
model.eval()
Run Code Online (Sandbox Code Playgroud)
编辑:整个错误堆栈:
Traceback (most recent call last):
File "load.py", line 6, in <module>
model = torch.load(PATH)
File "C:\Users\user\anaconda3\envs\pytorch-flask\lib\site-packages\torch\serialization.py", line 595, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "C:\Users\user\anaconda3\envs\pytorch-flask\lib\site-packages\torch\serialization.py", line 774, in _legacy_load
result = unpickler.load()
ModuleNotFoundError: No module named 'utils'
Run Code Online (Sandbox Code Playgroud)
编辑此答案不提供问题的答案,但解决了给定代码中的另一个问题
该.pth
文件仅存储模型的参数,而不存储模型本身。当您想要加载模型时,您将需要.pt/-h
模型类的文件和 python 代码。然后你可以像这样加载它:
# your model
class YourModel(nn.Modules):
def __init__(self):
super(YourModel, self).__init__()
. . .
def forward(self, x):
. . .
# the pytorch save-file in which you stored your trained model
model_file = "<your path>"
model = Model()
model = model.load_state_dict(torch.load(model_file))
model.eval()
Run Code Online (Sandbox Code Playgroud)
小智 5
我遇到了同样的错误,并且想知道问题是什么。原来问题是torch.load()
需要模块保存的数据utils
。
例子:
from utils import some_function
model = some_function()
torch.save(model)
Run Code Online (Sandbox Code Playgroud)
在给定示例中使用 torch 保存时,它会识别出使用模块 utils 来获取所需的数据。因此,在加载“.pth”文件时,您需要导入相同的模块utils
。
归档时间: |
|
查看次数: |
17630 次 |
最近记录: |