广播语义
许多 PyTorch 操作支持 NumPy 的广播语义。详情请参阅 https://numpy.org/doc/stable/user/basics.broadcasting.html。
简而言之,如果一个 PyTorch 操作支持广播,那么它的张量参数可以自动扩展为相同大小(而无需复制数据)。
通用语义
两个张量是“可广播”的,如果以下规则成立:
-
每个张量至少包含一个维度。
-
当从尾部维度开始遍历维度大小时,这些维度的大小必须相等,或者其中一个为 1,或者其中一个不存在。
例如:
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 形状相同的张量总是可广播的(即上述规则始终成立)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x 和 y 不可广播,因为 x 没有至少一个维度
# 可以对齐尾部维度
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x 和 y 是可广播的。
# 第一个尾部维度:两者的大小都是 1
# 第二个尾部维度:y 的大小是 1
# 第三个尾部维度:x 的大小等于 y 的大小
# 第四个尾部维度:y 没有这一维度
# 但是:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
# x 和 y 不可广播,因为在第三个尾部维度上 2 不等于 3
如果两个张量 x
和 y
是“可广播”的,则结果张量的大小按以下方式计算:
- 如果
x
和y
的维度数不同,则在维度较少的张量前面补 1,以使它们的维度数相同。 - 然后,对于每个维度,结果张量的维度大小是
x
和y
在该维度上的最大值。
例如:
# 可以通过对齐尾部维度来使阅读更方便
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# 但这不是必须的:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: 张量 a (2) 的大小必须与张量 b (3) 在非单例维度 1 上匹配
原地操作语义
一个复杂之处在于,原地操作不允许因广播而改变原地张量的形状。
例如:
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# 但是:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: 扩展后的张量大小 (1) 必须与现有大小 (7) 在非单例维度 2 上一致。
向后兼容性
以前版本的 PyTorch 允许某些逐点函数在张量形状不同的情况下执行,只要每个张量中的元素数量相等。逐点操作会通过将每个张量视为一维数组来执行。现在 PyTorch 支持广播,‘一维’逐点行为已被弃用。如果张量不可广播但具有相同的元素数量,将会生成 Python 警告。
需要注意的是,引入广播功能后,如果两个张量形状不同但可广播且具有相同元素数量,可能会出现向后不兼容的更改。例如:
>>> torch.add(torch.ones(4,1), torch.randn(4))
以前会生成一个大小为 torch.Size([4,1]) 的张量,但现在会生成一个大小为 torch.Size([4,4]) 的张量。为了帮助识别因广播引入的向后不兼容情况,您可以将 torch.utils.backcompat.broadcast_warning.enabled
设置为 True
,这会在这些情况下生成 Python 警告。
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: 这两个张量的形状不同,但可以进行广播,并且具有相同的元素数量。
以不兼容旧版本的方式更改行为,从将它们视为一维数组改为进行广播。