torch.nn.utils.convert_conv3d_weight_memory_format
- torch.nn.utils.convert_conv3d_weight_memory_format(module, memory_format)[源代码]
-
将
nn.Conv3d.weight
的memory_format
转换为目标格式。此转换递归应用于嵌套的nn.Module
,包括module
。请注意,它仅更改内存格式而不改变每个维度的语义。此函数用于方便计算采用NHWC内核,这为具有计算能力大于等于7.0的CUDA设备上的fp16数据提供了显著加速。注意
调用
model.to(memory_format=torch.channels_last_3d)
比实用函数convert_conv3d_weight_memory_format
更具侵略性。任何具有 4 维权重的层都会受到model.to
的影响,这不一定能从中受益。我们确信的是,在 cuDNN 中进行卷积的 NDHWC(channels_last_3d)转换是有益的,即使在需要对输入张量应用置换的情况下。因此,我们的策略是只将卷积权重转换为channels_last_3d。这样可以确保:1. 使用快速卷积内核,其性能提升可能超过因输入格式不同而产生的置换开销;2. 不会对不会从内存格式转换中受益的层进行不必要的置换操作。
最佳情况是,卷积层之间的层次都是通道最后兼容的。当输入张量遇到第一个卷积层时,会将其转换为通道最后的格式,并保持这种内存布局不变。因此,后续的卷积操作将不再需要转换其输入张量。
当通道最后的不兼容层位于两个卷积层之间时,需要将输入张量重新排列为连续格式以供该层使用。之后,输入张量将以连续格式通过剩余的所有层,并在遇到另一个卷积层时再次被重新排列为通道最后格式。由于大多数层对
memory_format
并不敏感,因此将这种重新排列传播到早期的层是没有意义的。当 PyTorch 支持置换融合时,这一说法可能会改变,因为除了在卷积前立即进行置换融合之外,可能存在更好的融合位置。
- 参数
-
-
module (nn.Module) –
nn.Conv3d
和nn.ConvTranspose3d
或容器nn.Module
-
memory_format — 用户指定的内存格式,例如
torch.channels_last
或torch.contiguous_format
-
- 返回值
-
包含更新后
nn.Conv3d
的原始模块
示例
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") >>> model = nn.Sequential( >>> nn.Conv3d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> out = model(input)