PyTorch - 运行时错误:加载 VGG 的 state_dict 时出错:

Bro*_*983 4 machine-learning deep-learning pytorch

我已经使用 PyTorch 训练了一个模型并保存了一个状态字典文件。我已经使用下面的代码加载了预先训练的模型。我收到有关 RuntimeError: Error(s) in loading state_dict for VGG 的错误消息:

RuntimeError: Error(s) in loading state_dict for VGG:
    Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "classifier.0.weight", "classifier.0.bias", "classifier.3.weight", "classifier.3.bias", "classifier.6.weight", "classifier.6.bias". 
    Unexpected key(s) in state_dict: "state_dict", "optimizer_state_dict", "globalStep", "train_paths", "test_paths". 

Run Code Online (Sandbox Code Playgroud)

我正在遵循此网站上提供的说明:https://pytorch.org/tutorials/beginner/ saving_loading_models.html# saving-loading-model-across-devices

非常感谢

import argparse
import datetime
import glob
import os
import random
import shutil
import time
from os.path import join

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor
from tqdm import tqdm
import torch.optim as optim

from convnet3 import Convnet
from dataset2 import CellsDataset

from convnet3 import Convnet
from VGG import VGG
from dataset2 import CellsDataset
from torchvision import models
from Conv import Conv2d

parser = argparse.ArgumentParser('Predicting hits from pixels')
parser.add_argument('name',type=str,help='Name of experiment')
parser.add_argument('data_dir',type=str,help='Path to data directory containing images and gt.csv')
parser.add_argument('--weight_decay',type=float,default=0.0,help='Weight decay coefficient (something like 10^-5)')
parser.add_argument('--lr',type=float,default=0.0001,help='Learning rate')
args = parser.parse_args()

metadata = pd.read_csv(join(args.data_dir,'gt.csv'))
metadata.set_index('filename', inplace=True)

# create datasets:

dataset = CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
dataset = DataLoader(dataset,num_workers=4,pin_memory=True)
model_path = '/Users/nubstech/Documents/GitHub/CellCountingDirectCount/VGG_model_V1/checkpoints/checkpoint.pth'

class VGG(nn.Module):
    def __init__(self, pretrained=True):
        super(VGG, self).__init__()
        vgg = models.vgg16(pretrained=pretrained)
        # if pretrained:
        vgg.load_state_dict(torch.load(model_path))
        features = list(vgg.features.children())
        self.features4 = nn.Sequential(*features[0:23])


        self.de_pred = nn.Sequential(Conv2d(512, 128, 1, same_padding=True, NL='relu'),
                                     Conv2d(128, 1, 1, same_padding=True, NL='relu'))


    def forward(self, x):
        x = self.features4(x)       
        x = self.de_pred(x)

        return x

model=VGG()
#model.load_state_dict(torch.load(model_path),strict=False)
model.eval()        

#optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

for images, paths in tqdm(dataset):

    targets = torch.tensor([metadata['count'][os.path.split(path)[-1]] for path in paths]) # B
    targets = targets.float()

    # code to print training data to a csv file
    #filename=CellsDataset(args.data_dir,transform=ToTensor(),return_filenames=True)
    output = model(images) # B x 1 x 9 x 9 (analogous to a heatmap)
    preds = output.sum(dim=[1,2,3]) # predicted cell counts (vector of length B)
    print(preds)
    paths_test = np.array([paths])
    names_preds = np.hstack(paths)
    print(names_preds)                
    df=pd.DataFrame({'Image_Name':names_preds, 'Target':targets.detach(), 'Prediction':preds.detach()})
    print(df) 
    # save image name, targets, and predictions
    df.to_csv(r'model.csv', index=False, mode='a')

Run Code Online (Sandbox Code Playgroud)

保存状态字典的代码

        torch.save({'state_dict':model.state_dict(),
                    'optimizer_state_dict':optimizer.state_dict(),
                    'globalStep':global_step,
                    'train_paths':dataset_train.files,
                    'test_paths':dataset_test.files},checkpoint_path)
Run Code Online (Sandbox Code Playgroud)

Ale*_*x I 5

问题是正在保存的内容与预期加载的内容不同。该代码尝试加载一个 state_dict;它节省的东西比这多得多 - 看起来像另一个带有附加信息的字典中的 state_dict 。load 方法没有任何逻辑来查看字典内部。

这应该有效:

import torch, torchvision.models
model = torchvision.models.vgg16()
path = 'test.pth'
torch.save(model.state_dict(), path) # nothing else here
model.load_state_dict(torch.load(path))
Run Code Online (Sandbox Code Playgroud)