Post by Vvv*_*vvv

How do I store model.state_dict() in a temporary variable for later use?

I tried to temporarily store the model's state dict in a variable so I could restore it into the model later, but the variable's contents change automatically as the model is updated.

Here is a minimal example:

import torch as t
import torch.nn as nn
from torch.optim import Adam


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
loss_fc = nn.MSELoss()
optimizer = Adam(net.parameters())

weights = net.state_dict()
print(weights)

x = t.rand((5, 3))
y = t.rand((5, 2))
loss = loss_fc(net(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(weights)

I expected the two outputs to be identical, but instead I got (the exact values may differ because of random initialization):

OrderedDict([('fc.weight', tensor([[-0.5557,  0.0544, -0.2277],
        [-0.0793,  0.4334, -0.1548]])), ('fc.bias', tensor([-0.2204,  0.2846]))])
OrderedDict([('fc.weight', tensor([[-0.5547,  0.0554, -0.2267],
        [-0.0783,  0.4344, -0.1538]])), ('fc.bias', tensor([-0.2194,  0.2856]))])
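For context, state_dict() returns an OrderedDict whose values are references to the model's live parameter tensors, so a snapshot taken before optimizer.step() still reflects the updated values afterwards. A minimal sketch of one common workaround, taking an independent copy with copy.deepcopy (this code is not part of the original post):

import copy

# Hypothetical snapshot (not from the original post): deep-copying the
# state dict detaches it from the live parameters, so later optimizer
# steps no longer modify the stored tensors.
weights_snapshot = copy.deepcopy(net.state_dict())

# ... train as in the example above ...

# Restore the snapshot into the model later.
net.load_state_dict(weights_snapshot)

With this, printing weights_snapshot before and after optimizer.step() should show identical values.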


Tags: python, pytorch

Score: 2 · Answers: 1 · Views: 1908
