BatchNorm2d 的 running_mean / running_var 在 PyTorch 中意味着什么?

yjy*_*131 3 mean variance deep-learning pytorch batch-normalization

我想知道我可以从running_mean和处running_var拨打电话nn.BatchNorm2d

示例代码在这里,其中 bn 表示nn.BatchNorm2d

vector = torch.cat([
    torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
    torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
    torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
    torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])
Run Code Online (Sandbox Code Playgroud)

我无法弄清楚Pytorch 官方文档和用户社区中的running_mean和是什么意思。running_var

nn.BatchNorm2.running_mean和是什么nn.BatchNorm2.running_var意思?

Iva*_*van 7

来自原始 Batchnorm 论文:

批量归一化:通过减少内部协变量偏移加速深度网络训练
Seguey IoffeChristian Szegedy ICML'2015

您可以在算法 1中看到如何测量给定批次的统计数据。

在此输入图像描述

然而,跨批次保存在内存中的是运行统计数据,在每个批次推理时迭代测量的统计数据。运行均值和运行方差的计算实际上在以下文档页面中得到了很好的解释nn.BatchNorm2d

在此输入图像描述

默认情况下,该momentum系数设置为0.1,它调节当前批次统计数据对运行统计数据的影响程度:

  • 更接近1意味着新的运行统计数据更接近当前批次统计数据,而

  • 更接近0意味着当前批次统计数据不会对更新新的运行统计数据做出太大贡献。

值得指出的是,它Batchnorm2d适用于空间维度,*此外*当然也适用于批量维度。给定一批形状(b, c, h, w),它将计算 的统计数据(b, h, w)。这意味着运行统计数据是成形的(c,)有与输入通道中一样多的统计组件(均值和方差)。

这是一个最小的例子:

>>> bn = nn.BatchNorm2d(10)
>>> x = torch.rand(2,10,2,2)
Run Code Online (Sandbox Code Playgroud)

由于默认track_running_stats设置为打开,因此在推断训练模式时它将跟踪跑步统计数据。True BatchNorm2d

运行均值和方差分别初始化为零和一。

>>> running_mean, running_var = torch.zeros(x.size(1)),torch.ones(x.size(1))
Run Code Online (Sandbox Code Playgroud)

让我们bn在训练模式下进行推理并检查其运行统计数据:

>>> bn(x)
>>> bn.running_mean, bn.running_var
(tensor([0.0650, 0.0432, 0.0373, 0.0534, 0.0476, 
         0.0622, 0.0651, 0.0660, 0.0406, 0.0446]),
 tensor([0.9027, 0.9170, 0.9162, 0.9082, 0.9087, 
         0.9026, 0.9136, 0.9043, 0.9126, 0.9122]))
Run Code Online (Sandbox Code Playgroud)

现在让我们手动计算这些统计数据:

>>> (1-momentum)*running_mean + momentum*xmean
tensor([[0.0650, 0.0432, 0.0373, 0.0534, 0.0476, 
         0.0622, 0.0651, 0.0660, 0.0406, 0.0446]])

>>> (1-momentum)*running_var + momentum*xvar
tensor([[0.9027, 0.9170, 0.9162, 0.9082, 0.9087, 
         0.9026, 0.9136, 0.9043, 0.9126, 0.9122]])
Run Code Online (Sandbox Code Playgroud)