广播语义

许多 PyTorch 操作支持 NumPy 的广播语义。详情请参阅 https://numpy.org/doc/stable/user/basics.broadcasting.html

简而言之,如果一个 PyTorch 操作支持广播,那么它的张量参数可以自动扩展为相同大小(而无需复制数据)。

通用语义

两个张量是“可广播”的,如果以下规则成立:

  • 每个张量至少包含一个维度。

  • 当从尾部维度开始遍历维度大小时,这些维度的大小必须相等,或者其中一个为 1,或者其中一个不存在。

例如:

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 形状相同的张量总是可广播的(即上述规则始终成立)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x 和 y 不可广播,因为 x 没有至少一个维度

# 可以对齐尾部维度
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x 和 y 是可广播的。
# 第一个尾部维度:两者的大小都是 1
# 第二个尾部维度:y 的大小是 1
# 第三个尾部维度:x 的大小等于 y 的大小
# 第四个尾部维度:y 没有这一维度

# 但是:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x 和 y 不可广播,因为在第三个尾部维度上 2 不等于 3

如果两个张量 xy 是“可广播”的,则结果张量的大小按以下方式计算:

  • 如果 xy 的维度数不同,则在维度较少的张量前面补 1,以使它们的维度数相同。
  • 然后,对于每个维度,结果张量的维度大小是 xy 在该维度上的最大值。

例如:

# 可以通过对齐尾部维度来使阅读更方便
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# 但这不是必须的:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: 张量 a (2) 的大小必须与张量 b (3) 在非单例维度 1 上匹配

原地操作语义

一个复杂之处在于,原地操作不允许因广播而改变原地张量的形状。

例如:

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# 但是:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: 扩展后的张量大小 (1) 必须与现有大小 (7) 在非单例维度 2 上一致。

向后兼容性

以前版本的 PyTorch 允许某些逐点函数在张量形状不同的情况下执行,只要每个张量中的元素数量相等。逐点操作会通过将每个张量视为一维数组来执行。现在 PyTorch 支持广播,‘一维’逐点行为已被弃用。如果张量不可广播但具有相同的元素数量,将会生成 Python 警告。

需要注意的是,引入广播功能后,如果两个张量形状不同但可广播且具有相同元素数量,可能会出现向后不兼容的更改。例如:

>>> torch.add(torch.ones(4,1), torch.randn(4))

以前会生成一个大小为 torch.Size([4,1]) 的张量,但现在会生成一个大小为 torch.Size([4,4]) 的张量。为了帮助识别因广播引入的向后不兼容情况,您可以将 torch.utils.backcompat.broadcast_warning.enabled 设置为 True,这会在这些情况下生成 Python 警告。

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: 这两个张量的形状不同,但可以进行广播,并且具有相同的元素数量。
以不兼容旧版本的方式更改行为,从将它们视为一维数组改为进行广播。
本页目录