torch.linalg.multi_dot
- torch.linalg.multi_dot(tensors, *, out=None)
-
通过优化矩阵相乘的顺序,减少所需的算术运算次数,从而高效地计算两个或多个矩阵的乘积。
支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型。此函数不支持批处理输入。
除了第一个和最后一个张量可以是一维的之外,
tensors
中的所有张量都必须是二维的。如果第一个张量是一个形状为(n,)的一维向量,则将其视为形状为(1, n)的行向量;同样地,如果最后一个张量是一个形状为(n,)的一维向量,则将其视为形状为(n, 1)的列向量。如果第一个和最后一个张量都是矩阵,那么输出也将是一个矩阵。然而,只要其中一个是一维向量,输出就会变成一维向量。
与numpy.linalg.multi_dot的不同之处:
-
与numpy.linalg.multi_dot不同,第一个和最后一个张量必须是1D或2D,而NumPy则允许它们是任意维度(nD)。
警告
此函数不会进行广播。
注意
此函数通过先计算最优的矩阵乘法顺序,然后链式调用
torch.mm()
来实现。注意
形状为(a, b)和(b, c)的两个矩阵相乘的成本是a * b * c。给定矩阵A、B和C,它们的形状分别为(10, 100)、(100, 5)和(5, 50),我们可以按照以下方式计算不同乘法顺序的成本:
$\begin{align*} \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 \end{align*}$在这种情况下,先计算 A 和 B 的乘积,然后再与 C 运算,速度会快10倍。
- 参数
-
tensors (Sequence[Tensor]) – 要相乘的两个或多个张量。第一个和最后一个张量可以是1D或2D,而其他所有张量必须为2D。
- 关键字参数
-
out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。
示例:
>>> from torch.linalg import multi_dot >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) tensor(8) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) tensor([8]) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) tensor([[8]]) >>> A = torch.arange(2 * 3).view(2, 3) >>> B = torch.arange(3 * 2).view(3, 2) >>> C = torch.arange(2 * 2).view(2, 2) >>> multi_dot((A, B, C)) tensor([[ 26, 49], [ 80, 148]])
-