如何在 PyTorch 中更新神经网络的参数?

the*_*ass 1 pytorch

比方说,我想乘在神经网络中的所有参数PyTorch(从继承类的实例,torch.nn.Module通过)0.9。我该怎么做?

the*_*ass 5

net一个你的神经网络类的实例。然后你可以做

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    state_dict[name].copy_(transformed_param)
Run Code Online (Sandbox Code Playgroud)

将所有参数乘以0.9

如果您只想更新权重而不是所有参数,您可以这样做

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if not "weight" in name:
        continue

    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    state_dict[name].copy_(transformed_param)
Run Code Online (Sandbox Code Playgroud)