我怎样才能用pytorch更新网络中的某些特定张量?

gas*_*oon 3 machine-learning image-processing deep-learning conv-neural-network pytorch

例如,我想只在前10个时期更新Resnet中的所有cnn权重并冻结其他时期.
从第11个时代开始,我想改变整个模型.
我怎样才能实现目标?

Sha*_*hai 6

您可以为每个参数组设置学习速率(以及一些其他元参数).您只需根据需要对参数进行分组.
例如,为conv层设置不同的学习率:

import torch
import itertools
from torch import nn

conv_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                             if isinstance(m, nn.Conv2d)])
other_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                              if not isinstance(m, nn.Conv2d)]) 
optimizer = torch.optim.SGD([{'params': other_params},
                             {'params': conv_params, 'lr': 0}],  # set init lr to 0
                            lr=lr_for_model)
Run Code Online (Sandbox Code Playgroud)

您可以稍后访问优化程序param_groups并修改学习速率.

有关详细信息,请参阅每个参数选项.