Pau*_*ers 3 python deep-learning pytorch
I have the following PyTorch model:
import math
from abc import abstractmethod

import torch.nn as nn


class AlexNet3D(nn.Module):
    @abstractmethod
    def get_head(self):
        pass

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.features = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=0),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),
            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=0),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),
            nn.Conv3d(128, 192, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),
            nn.Conv3d(192, 192, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),
            nn.Conv3d(192, 128, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),
        )
        self.classifier = self.get_head()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        xp = self.features(x)
        x = xp.view(xp.size(0), -1)
        x = self.classifier(x)
        return [x, xp]


class AlexNet3DDropoutRegression(AlexNet3D):
    def get_head(self):
        return nn.Sequential(nn.Dropout(),
                             nn.Linear(self.input_size, 64),
                             nn.ReLU(inplace=True),
                             nn.Dropout(),
                             nn.Linear(64, 1),
                             )
I initialize the model like this:
def init_model(self):
    model = AlexNet3DDropoutRegression(4608)
    if self.use_cuda:
        log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(self.device)
    return model
After training, I save the model like this:
torch.save(self.model.state_dict(), self.cli_args.model_save_location)
Then I try to load the saved model:
import torch
from reprex.models import AlexNet3DDropoutRegression
model_save_location = "/home/feczk001/shared/data/AlexNet/LoesScoring/loes_scoring_01.pt"
model = AlexNet3DDropoutRegression(4608)
model.load_state_dict(torch.load(model_save_location,
                                 map_location='cpu'))
But I get the following error:
RuntimeError: Error(s) in loading state_dict for AlexNet3DDropoutRegression:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.weight", "features.4.bias", "features.5.weight", "features.5.bias", "features.5.running_mean", "features.5.running_var", "features.8.weight", "features.8.bias", "features.9.weight", "features.9.bias", "features.9.running_mean", "features.9.running_var", "features.11.weight", "features.11.bias", "features.12.weight", "features.12.bias", "features.12.running_mean", "features.12.running_var", "features.14.weight", "features.14.bias", "features.15.weight", "features.15.bias", "features.15.running_mean", "features.15.running_var", "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias".
Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.1.weight", "module.features.1.bias", "module.features.1.running_mean", "module.features.1.running_var", "module.features.1.num_batches_tracked", "module.features.4.weight", "module.features.4.bias", "module.features.5.weight", "module.features.5.bias", "module.features.5.running_mean", "module.features.5.running_var", "module.features.5.num_batches_tracked", "module.features.8.weight", "module.features.8.bias", "module.features.9.weight", "module.features.9.bias", "module.features.9.running_mean", "module.features.9.running_var", "module.features.9.num_batches_tracked", "module.features.11.weight", "module.features.11.bias", "module.features.12.weight", "module.features.12.bias", "module.features.12.running_mean", "module.features.12.running_var", "module.features.12.num_batches_tracked", "module.features.14.weight", "module.features.14.bias", "module.features.15.weight", "module.features.15.bias", "module.features.15.running_mean", "module.features.15.running_var", "module.features.15.num_batches_tracked", "module.classifier.1.weight", "module.classifier.1.bias", "module.classifier.4.weight", "module.classifier.4.bias".
What is going wrong here?
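The problem is that you trained your model with DataParallel and are then trying to reload it into a non-parallel network. DataParallel is a wrapper class that stores the original model (a torch.nn.Module object) as an attribute named module on the DataParallel object. As a result, every key in the wrapper's state_dict carries a "module." prefix, which is exactly the mismatch the error message reports. This problem has been addressed on the PyTorch discussion forum, Stack Overflow, and GitHub, so I won't repeat the details here, but you can resolve it in either of the following ways:
1. Save and load your model only as a DataParallel object, which will likely stop working when you want to use the model for inference, or
2. Save the state_dict of the DataParallel object's module attribute, like this: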
# save the state dict of the model wrapped by the DataParallel object
torch.save(model.module.state_dict(), path)
# ... later
# reload the weights into a non-parallel model
model.load_state_dict(torch.load(path))
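To see where the "module." prefix comes from, here is a minimal sketch that compares the state_dict keys of a plain module with those of the same module wrapped in DataParallel. The tiny nn.Linear layer and the names net and wrapped are placeholders chosen for illustration, not part of the original model; the snippet should also run on a machine without a GPU.

```python
import torch.nn as nn

# Hypothetical stand-in for the real network: a single linear layer.
net = nn.Linear(4, 2)

# DataParallel keeps the original network in its `module` attribute.
wrapped = nn.DataParallel(net)

plain_keys = list(net.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
inner_keys = list(wrapped.module.state_dict().keys())

print(plain_keys)    # ['weight', 'bias']
print(wrapped_keys)  # ['module.weight', 'module.bias']
print(inner_keys)    # ['weight', 'bias']
```

Saving wrapped.module.state_dict() therefore produces a file whose keys match what a non-parallel model expects.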
Here is a simple example:
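import torch
import torch.nn as nn

model = AlexNet3DDropoutRegression(4608)  # on cpu
model = nn.DataParallel(model)
model = model.to("cuda")  # DataParallel object on GPU(s)
# save the weights of the wrapped model, without the "module." prefix
torch.save(model.module.state_dict(), "example_path.pt")
del model
# reload the weights into a fresh non-parallel model
model = AlexNet3DDropoutRegression(4608)
ret = model.load_state_dict(torch.load("example_path.pt"))
print(ret)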
Output:
>>> <All keys successfully matched>
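Alternatively, and probably more useful if you already have a state_dict file that needs to be reloaded, you can load the state_dict file saved from the DataParallel model, remap the key names to drop the "module." prefix, and then load the re-keyed state_dict. Something like:
incompatible_state_dict = torch.load("DataParallel_save_file.pt")
state_dict = {}
# iterate over the saved keys and drop the "module." prefix from each one
for key in incompatible_state_dict:
    state_dict[key.split("module.")[-1]] = incompatible_state_dict[key]
# model is the non-parallel AlexNet3DDropoutRegression instance
ret = model.load_state_dict(state_dict)
print(ret)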
Output:
>>> <All keys successfully matched>
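The key remapping above can also be written as a small reusable helper. The sketch below is one possible variant, not the only way to do it: strip_prefix is a hypothetical name, and the dictionaries use plain integers as stand-ins for tensors so the snippet runs without PyTorch. It removes the prefix only when a key starts with it, whereas key.split("module.")[-1] would also shorten a key such as "encoder.submodule.weight" to "weight".

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys that do not start with `prefix` are copied unchanged.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Stand-in for the mapping returned by torch.load(...); values would be tensors.
saved = OrderedDict([
    ("module.features.0.weight", 1),
    ("module.classifier.1.bias", 2),
])
cleaned = strip_prefix(saved)
print(list(cleaned))  # ['features.0.weight', 'classifier.1.bias']
```

With a real checkpoint, the result of strip_prefix(torch.load(path, map_location="cpu")) can be passed to model.load_state_dict. Recent PyTorch releases also appear to ship torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which performs a similar in-place rewrite; check that it is available in your installed version before relying on it.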