SO 有一个关于如何检查模型参数总数的答案:
pytorch_total_params = sum(p.numel() for p in model.parameters())
但是,如何检查 中的参数总数state_dict?
state_dict = torch.load(model_path, map_location='cpu')?
您可以计算 state_dict 中保存的条目数:
sum(p.numel() for p in state_dict.values())
Run Code Online (Sandbox Code Playgroud)
然而,这里有一个障碍:state_dict 存储参数 和 持久缓冲区(例如,BatchNorm 的运行平均值和变量)。没有办法(据我所知)将它们与 state_dict 本身区分开来,您需要将它们加载到模型中并用于sum(p.numel() for p in model.parameters()仅计算参数。
例如,如果您结账resnet50
from torchvision.models import resnet50
model = resnet50(pretrained=True)
state_dict = torch.load('~/.torch/models/resnet50-19c8e357.pth')
num_parameters = sum(p.numel() for p in model.parameters())
num_state_dict = sum(p.numel() for p in state_dict.values())
print('num parameters = {}, stored in state_dict = {}, diff = {}'.format(num_parameters, num_state_dict, num_state_dict - num_parameters))
Run Code Online (Sandbox Code Playgroud)
结果是
Run Code Online (Sandbox Code Playgroud)num parameters = 25557032, stored in state_dict = 25610152, diff = 53120
正如您所看到的,两个值之间可能存在很大差距。
| 归档时间: |
|
| 查看次数: |
2843 次 |
| 最近记录: |