torch.linalg.multi_dot

torch.linalg.multi_dot(tensors, *, out=None)

通过优化矩阵相乘的顺序,减少所需的算术运算次数,从而高效地计算两个或多个矩阵的乘积。

支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型。此函数不支持批处理输入。

除了第一个和最后一个张量可以是一维的之外,tensors 中的所有张量都必须是二维的。如果第一个张量是一个形状为(n,)的一维向量,则将其视为形状为(1, n)的行向量;同样地,如果最后一个张量是一个形状为(n,)的一维向量,则将其视为形状为(n, 1)的列向量。

如果第一个和最后一个张量都是矩阵,那么输出也将是一个矩阵。然而,只要其中一个是一维向量,输出就会变成一维向量。

numpy.linalg.multi_dot的不同之处:

  • numpy.linalg.multi_dot不同,第一个和最后一个张量必须是1D或2D,而NumPy则允许它们是任意维度(nD)。

警告

此函数不会进行广播。

注意

此函数通过先计算最优的矩阵乘法顺序,然后链式调用torch.mm()来实现。

注意

形状为(a, b)(b, c)的两个矩阵相乘的成本是a * b * c。给定矩阵ABC,它们的形状分别为(10, 100)(100, 5)(5, 50),我们可以按照以下方式计算不同乘法顺序的成本:

$\begin{align*} \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 \end{align*}$

在这种情况下,先计算 AB 的乘积,然后再与 C 运算,速度会快10倍。

参数

tensors (Sequence[Tensor]) – 要相乘的两个或多个张量。第一个和最后一个张量可以是1D或2D,而其他所有张量必须为2D。

关键字参数

out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。

示例:

>>> from torch.linalg import multi_dot

>>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
tensor(8)
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
tensor([8])
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
tensor([[8]])

>>> A = torch.arange(2 * 3).view(2, 3)
>>> B = torch.arange(3 * 2).view(3, 2)
>>> C = torch.arange(2 * 2).view(2, 2)
>>> multi_dot((A, B, C))
tensor([[ 26,  49],
        [ 80, 148]])
本页目录