如何使用 torch.hub.load 加载本地模型?

cod*_*com 11 python machine-learning torch pytorch torchvision

我需要避免从网上下载模型(由于安装的机器的限制)。

这可行,但它从互联网下载模型

model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True)
Run Code Online (Sandbox Code Playgroud)

我已将.pth文件和hubconf.py文件放在 /tmp/ 文件夹中,并将我的代码更改为

model = torch.hub.load('/tmp/', 'deeplabv3_resnet101', pretrained=True, source='local')
Run Code Online (Sandbox Code Playgroud)

但令我惊讶的是,它仍然从互联网上下载模型。我究竟做错了什么?如何在本地加载模型?

只是为了向您提供更多详细信息,我在 Docker 容器中执行所有这些操作,该容器在运行时具有只读卷,因此这就是新文件下载失败的原因。

小智 8

您可以采用两种方法在没有 Internet 连接的计算机上获取可发布的模型。

  1. 在普通机器上加载带有预训练模型的 DeepLab,使用JIT编译器将其导出为图,然后将其放入机器中。该脚本很容易遵循:

     # To export
     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=True).eval()
     traced_graph = torch.jit.trace(model, torch.randn(1, 3, H, W))
     traced_graph.save('DeepLab.pth')
    
     # To load
     model = torch.jit.load('DeepLab.pth').eval().to(device)
    
    Run Code Online (Sandbox Code Playgroud)

    在这种情况下,权重和网络结构将保存为计算图,因此您不需要任何额外的文件。

  2. 查看torchvision 的 GitHub 存储库

    有一个带有 Resnet101 主干权重的 DeepLabV3 的下载 URL

    您可以下载这些权重一次,然后使用 torchvision 的 deeplab 和pretrained=False标志并手动加载权重。

     model = torch.hub.load('pytorch/vision:v0.9.0', 'deeplabv3_resnet101', pretrained=False)
     model.load_state_dict(torch.load('downloaded weights path'))
    
    Run Code Online (Sandbox Code Playgroud)

    考虑到,状态字典中可能有一个['state_dict']或一些类似的父键,您可以在其中使用:

     model.load_state_dict(torch.load('downloaded weights path')['state_dict'])
    
    Run Code Online (Sandbox Code Playgroud)

  • H 和 W 是什么? (2认同)

小智 6

model_name='best.pt'
model = torch.hub.load(os.getcwd(), 'custom', source='local', path = model_name, force_reload = True)
Run Code Online (Sandbox Code Playgroud)

这对我有用。默认来源是github。