torch.chain_matmul

torch.chain_matmul(*matrices, out=None)[源代码]

返回$N$个2-D张量的矩阵乘积。该乘积通过矩阵链顺序算法高效计算,选择在算术运算成本最低的情况下进行操作的顺序([CLRS])。需要注意的是,由于这是一个用于计算乘积的函数,因此$N$需要大于或等于2;如果$N$为2,则返回一个简单的矩阵-矩阵乘积。若$N$为1,此操作无意义 - 返回原始矩阵。

警告

torch.chain_matmul() 已被弃用,并将在未来的 PyTorch 版本中移除。建议使用torch.linalg.multi_dot(),它接受一个包含两个或更多张量的列表作为参数。

参数
  • 矩阵 (张量...) – 包含两个或更多二维张量的序列,它们的乘积待确定。

  • out (Tensor, 可选) – 输出张量。如果 outNone,则忽略。

返回值

如果第 $i^{th}$ 张量的维度是 $p_{i} \times p_{i + 1}$,那么它们相乘后的结果将具有维度 $p_{1} \times p_{N + 1}$

返回类型

张量

示例:

>>> a = torch.randn(3, 4)
>>> b = torch.randn(4, 5)
>>> c = torch.randn(5, 6)
>>> d = torch.randn(6, 7)
>>> # will raise a deprecation warning
>>> torch.chain_matmul(a, b, c, d)
tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
        [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
        [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])
本页目录