如何在 pytorch 中加载模型而不必记住使用的参数?

Mic*_*ael 3 python pytorch

我正在 pytorch 中训练一个模型,我为此创建了一个类,如下所示:

from torch import nn

class myNN(nn.Module):
    def __init__(self, dense1=128, dense2=64, dense3=32, ...):
         self.MLP = nn.Sequential(
            nn.Linear(dense1, dense2),
            nn.ReLU(),
            nn.Linear(dense2, dense3),
            nn.ReLU(),
            nn.Linear(dense3, 1)
         )
         ...
Run Code Online (Sandbox Code Playgroud)

为了保存它,我正在使用:

torch.save(model.state_dict(), checkpoint_model_path)
Run Code Online (Sandbox Code Playgroud)

并加载它我正在使用:

model = myNN()   # or with specified parameters
model.load_state_dict(torch.load(model_file))
Run Code Online (Sandbox Code Playgroud)

但是,为了使此方法起作用,我必须在 myNN() 的构造函数中使用正确的值。这意味着我需要以某种方式记住或存储我在每种情况下使用的参数(层大小),以便正确加载不同的模型。

有没有一种灵活的方法可以在 pytorch 中保存/加载模型,我还可以读取层的大小?

例如,通过直接加载 myNN() 对象或以某种方式从保存的 pickle 文件中读取层大小?

我犹豫是否要尝试在 PyTorch 中保存训练模型的最佳方式中的第二种方法?由于那里提到的警告。有更好的方法来实现我想要的吗?

Iva*_*van 11

事实上,序列化整个 Python 是一个相当激进的举动。相反,您始终可以在保存的文件中添加用户定义的项目:您可以保存模型的状态及其类参数。像这样的事情会起作用:

  1. 首先将参数保存在实例中,以便我们可以在保存模型时序列化它们:

    class myNN(nn.Module):
        def __init__(self, dense1=128, dense2=64, dense3=32):
            super().__init__()
            self.kwargs = {'dense1': dense1, 'dense2': dense2, 'dense3': dense3}
            self.MLP = nn.Sequential(
                nn.Linear(dense1, dense2),
                nn.ReLU(),
                nn.Linear(dense2, dense3),
                nn.ReLU(),
                nn.Linear(dense3, 1))
    
    Run Code Online (Sandbox Code Playgroud)
  2. 我们可以保存模型的参数及其初始化参数:

    >>> torch.save([model.kwargs, model.state_dict()], path)
    
    Run Code Online (Sandbox Code Playgroud)
  3. 然后加载它:

    >>> kwargs, state = torch.load(path)
    >>> model = myNN(**kwargs)
    >>> model.load_state_dict(state)
    <All keys matched successfully>
    
    Run Code Online (Sandbox Code Playgroud)