检查pytorch中state_dict的参数数量

Aer*_*rin 5 pytorch

SO 有一个关于如何检查模型参数总数的答案: pytorch_total_params = sum(p.numel() for p in model.parameters())

但是,如何检查 中的参数总数state_dict

state_dict = torch.load(model_path, map_location='cpu')

Sha*_*hai 7

您可以计算 state_dict 中保存的条目数:

sum(p.numel() for p in state_dict.values())
Run Code Online (Sandbox Code Playgroud)

然而,这里有一个障碍:state_dict 存储参数 持久缓冲区(例如,BatchNorm 的运行平均值和变量)。没有办法(据我所知)将它们与 state_dict 本身区分开来,您需要将它们加载到模型中并用于sum(p.numel() for p in model.parameters()仅计算参数。

例如,如果您结账resnet50

from torchvision.models import resnet50
model = resnet50(pretrained=True)
state_dict = torch.load('~/.torch/models/resnet50-19c8e357.pth')
num_parameters = sum(p.numel() for p in model.parameters())
num_state_dict = sum(p.numel() for p in state_dict.values())
print('num parameters = {}, stored in state_dict = {}, diff = {}'.format(num_parameters, num_state_dict, num_state_dict - num_parameters))
Run Code Online (Sandbox Code Playgroud)

结果是

num parameters = 25557032, stored in state_dict = 25610152, diff = 53120
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,两个值之间可能存在很大差距。