修复批量归一化

发生了什么事?

Batch Norm 要求 running_mean 和 running_var 的大小与输入相同,并对其进行原地更新。Functorch 不支持将批量张量传递给常规张量的原地更新操作(即 regular.add_(batched) 不被允许)。因此,在单个模块上对一批输入进行 vmapping 时,会出现此错误。

如何解决

其中一个最好的支持方法是将 BatchNorm 切换为 GroupNorm。选项 1 和选项 2 支持这一点。

所有这些选项都假设你不需要实时统计信息。如果你使用的是模块,则默认情况下不会在评估模式中启用批量归一化。如果需要在这种模式下使用vmap进行批量归一化,请提交一个问题。

选项 1:修改批处理规范化

如果你想改为GroupNorm,在原来使用BatchNorm的地方进行替换:

BatchNorm2d(C, G, track_running_stats=False)

在这里,C 与原始 BatchNorm 中的 C 相同。而 G 表示将 C 分成的组数。因此,需要满足条件 C % G == 0。如果无法满足该条件,则可以将 C 设置为等于 G,这样每个通道都会被单独处理。

如果你必须使用BatchNorm,并且自己构建了该模块,可以将其修改为不使用运行统计。具体来说,将所有BatchNorm模块的track_running_stats标志设置为False。

BatchNorm2d(64, track_running_stats=False)

选项 2:torchvision 参数

一些torchvision模型(如resnet和regnet)可以接受一个norm_layer参数,如果没有特别指定,默认值通常为BatchNorm2d

你可以将它设置为 GroupNorm。

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))

在这里,再一次,当c % g == 0时,将g = c作为备选方案。

如果你依赖于BatchNorm,确保使用不计算运行统计的版本

import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

选项 3:使用 functorch 的补丁

functorch 增加了一些功能,允许快速、就地修改模块以禁用累积统计量的使用。更改归一化层较为脆弱,因此我们没有提供该选项。如果你希望某个网络中的 BatchNorm 层不使用累积统计量,可以运行 replace_all_batch_norm_modules_ 来就地更新模块。

from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)

选项 4:eval 模式

在评估模式下运行时,running_mean 和 running_var 不会更新。因此,vmap 可以支持这种模式。

model.eval()
vmap(model)(x)
model.train()
本页目录