GroupNorm 比 Pytorch 中的 BatchNorm 慢得多,并且消耗更高的 GPU 内存

zbh*_*047 5 pytorch

我在 pytorch 中使用 GroupNorm 而不是 BatchNorm 并保持所有其他(网络架构)不变。结果表明,在 Imagenet 数据集中,使用 resnet50 架构,GroupNorm 比 BatchNorm 慢 40%,并且比 BatchNorm 多消耗 33% 的 GPU 内存。我真的很困惑,因为 GroupNorm 不应该比 BatchNorm 需要更多的计算。详情如下。

有关 Group Normalization 的详细信息,可以参见这篇论文:https : //arxiv.org/pdf/1803.08494.pdf

对于 BatchNorm,GPU 内存为 7.51GB,一个 minibatch 耗时 12.8 秒;

对于 GroupNorm,一个 minibatch 消耗 17.9 秒,GPU 内存为 10.02GB。

我使用以下代码将所有 BatchNorm 层转换为 GroupNorm 层。

def convert_bn_model_to_gn(module, num_groups=16):
"""
Recursively traverse module and its children to replace all instances of
``torch.nn.modules.batchnorm._BatchNorm`` with :class:`torch.nn.GroupNorm`.
Args:
    module: your network module
    num_groups: num_groups of GN
"""
mod = module
if isinstance(module, nn.modules.batchnorm._BatchNorm):
    mod = nn.GroupNorm(num_groups, module.num_features,
                       eps=module.eps, affine=module.affine)
    # mod = nn.modules.linear.Identity()
    if module.affine:
        mod.weight.data = module.weight.data.clone().detach()
        mod.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
    mod.add_module(name, convert_bn_model_to_gn(
        child, num_groups=num_groups))
del module
return mod
Run Code Online (Sandbox Code Playgroud)

Rev*_*Gen 1

是的,你是对的,与 BN 相比,GN 确实使用了更多的资源。我猜这是因为它必须计算每组通道的均值和方差,而 BN 只需要在整个批次中计算一次。

但 GN 的优点是,您可以将 Batch Size 降低到 2,而不会降低任何性能,如论文中所述,因此您可以弥补计算开销。