torch.einsum
- torch.einsum(equation, *operands) → Tensor[源代码]
-
沿使用基于爱因斯坦求和约定的记号指定的维度,计算输入
operands
元素乘积的总和。Einsum 允许通过基于爱因斯坦求和约定的简写格式来计算许多常见的多维线性代数数组操作。具体来说,这种格式涉及用一些下标标记输入
operands
的每个维度,并定义哪些下标属于输出。然后,沿着不属于输出的维度计算operands
元素乘积的和来得到最终结果。例如,矩阵乘法可以使用 einsum 以 torch.einsum(“ij,jk->ik”, A, B) 的形式进行计算,在这里 j 是求和下标,而 i 和 k 则是输出下标(详情请参见下面的部分)。方程:
The
equation
字符串指定了每个输入operands
维度的下标([a-zA-Z] 中的字母),顺序与维度相同,用逗号(‘,’)分隔每个操作数的下标。例如,‘ij,jk’ 指定了两个 2D 操作数的下标。带有相同下标的维度必须可以广播,也就是说,它们的大小要么匹配,要么为1。例外情况是如果同一个输入操作数的下标被重复使用,则该操作数在这些维度上的尺寸必须相等,并且该操作数将被其沿这些维度的对角线替换。equation
中恰好出现一次的下标将是输出的一部分,按字母顺序升序排列。输出是通过基于下标的维度对齐输入operands
的元素进行逐元素相乘,并将不属于输出部分的维度上的元素求和来计算得出。可选地,可以通过在方程式末尾添加箭头(“->”)并跟随输出下标来显式定义输出下标。例如,以下方程计算矩阵乘法的转置:“ij,jk->ki”。输出下标必须至少在一个输入操作数中出现一次,并且在输出结果中只能出现一次。
可以使用省略号(“…”)来表示未明确指定的维度。每个输入操作数最多只能包含一个省略号,该省略号覆盖未被下标表示的维度。例如,在具有5个维度的输入操作数中,方程‘ab…c’中的省略号覆盖了第三和第四维度。虽然不同的
operands
之间省略号不需要覆盖相同的维度数量,但它们所覆盖的维度大小必须能够广播在一起。如果没有使用箭头(“->”)符号显式定义输出,则输出中的省略号将出现在最左边的位置,在输入操作数中恰好出现一次的下标标签之前。例如,方程‘…ij,…jk’实现了批量矩阵乘法。一些最后的注意事项:方程式中的不同元素(下标、省略号、箭头和逗号)之间可以包含空格,但类似‘…’的形式是无效的。对于标量运算符,空字符串‘’是有效的。
注意
torch.einsum
处理省略号(‘...’)的方式与 NumPy 不同,它允许对省略号所涵盖的维度进行求和操作,而省略号不必成为输出的一部分。注意
此函数使用 opt_einsum (https://optimized-ein-sum.readthedocs.io/en/stable/) 来通过优化收缩顺序来加速计算或减少内存消耗。当至少有三个输入时,才会进行此优化,因为在这种情况下顺序无关紧要。需要注意的是,找到最优路径是一个NP难问题,因此 opt_einsum 依赖于不同的启发式方法来实现接近最优的结果。如果 opt_einsum 不可用,则默认的收缩顺序是从左到右。
要绕过此默认行为,请添加以下行以禁用 opt_einsum 的使用并跳过路径计算: torch.backends.opt_einsum.enabled = False
要指定你希望 opt_einsum 使用哪种策略来计算收缩路径,请添加以下行:torch.backends.opt_einsum.strategy = 'auto'。默认策略是 'auto',我们还支持 'greedy' 和 'optimal' 策略。需要注意的是,'optimal' 的运行时间与输入数量的阶乘成正比!更多详情请参阅 opt_einsum 文档 (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。
注意
从 PyTorch 1.10 开始,
torch.einsum()
支持子列表格式(见下面的示例)。在这种格式中,每个操作数的下标由一个包含整数(范围在 [0, 52) 内)的子列表指定。这些子列表紧跟在其对应的操作数之后,并且可以在输入末尾添加一个额外的子列表来指定输出的下标,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 对象可以用于子列表中以启用广播,如上文方程部分所述。示例:
>>> # trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])