PyTorch中的register_parameter和register_buffer有什么区别?

vai*_*ijr 3 pytorch

参数在训练过程中被更改,也就是说,它们是在神经网络训练过程中学习到的东西,但是什么是缓冲区?

在神经网络训练中是学到的吗?

pro*_*sti 17

您为模块 ( nn.Module)创建的参数和缓冲区。

假设你有一个线性层nn.Linear。你已经有了weightbias参数。但是如果你需要一个新参数,你可以register_parameter()用来注册一个新的命名参数,它是一个张量。

当您注册一个新参数时,它会出现在module.parameters()迭代器中,但是当您注册一个缓冲区时,它不会出现。

区别:

缓冲区被命名为张量,它不会像参数一样在每一步更新梯度。对于缓冲区,您可以创建自定义逻辑(完全由您决定)。

好消息是,当您保存模型时,所有参数和缓冲区都会被保存,并且当您将模型移入或移出 CUDA 时,参数和缓冲区也会随之消失。


Sha*_*hai 5

Pytorch doc用于register_buffer()方法读取

通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm running_mean不是参数,而是持久状态的一部分。

如您所见,在训练过程中使用SGD学习并更新了模型参数
但是,有时还有其他数量属于模型“状态”的一部分,应
另存为state_dict
-移至模型的其余参数cuda()cpu()与之一起使用。
-转换成float/ half/ double与模型的参数的其余部分。
将这些“参数”注册为模型buffer可以使pytorch跟踪它们并像常规参数一样保存它们,但是可以防止pytorch使用SGD机制更新它们。

用于缓冲的一个例子中可以找到_BatchNorm模块,其中running_meanrunning_varnum_batches_tracked通过累积通过所述层转发的数据的统计信息被登记为缓冲器和更新。这与使用常规SGD优化学习数据的仿射变换的参数weightbias参数相反。

  • 这就提出了一个问题:通过 [`register_buffer(name, tensor, permanent=False)`](https://pytorch.org/docs/stable/ generated/torch.nn.Module 注册的缓冲区的用例是什么。 html#torch.nn.Module.register_buffer),即,哪些甚至不是“state_dict”的一部分? (3认同)
  • @bluenote10 这很有意义。假设您有一个常量张量作为“nn.Module”的一部分(例如,位置嵌入等),并且您希望确保该张量被移动到适当的设备,并在每次调用“.”时转换为正确的“dtype”。 to()` 为你的`nn.Module`... (3认同)