torch.chain_matmul
- torch.chain_matmul(*matrices, out=None)[源代码]
-
返回$N$个2-D张量的矩阵乘积。该乘积通过矩阵链顺序算法高效计算,选择在算术运算成本最低的情况下进行操作的顺序([CLRS])。需要注意的是,由于这是一个用于计算乘积的函数,因此$N$需要大于或等于2;如果$N$为2,则返回一个简单的矩阵-矩阵乘积。若$N$为1,此操作无意义 - 返回原始矩阵。
警告
torch.chain_matmul()
已被弃用,并将在未来的 PyTorch 版本中移除。建议使用torch.linalg.multi_dot()
,它接受一个包含两个或更多张量的列表作为参数。- 参数
-
-
矩阵 (张量...) – 包含两个或更多二维张量的序列,它们的乘积待确定。
-
out (Tensor, 可选) – 输出张量。如果
out
为None
,则忽略。
-
- 返回值
-
如果第 $i^{th}$ 张量的维度是 $p_{i} \times p_{i + 1}$,那么它们相乘后的结果将具有维度 $p_{1} \times p_{N + 1}$。
- 返回类型
示例:
>>> a = torch.randn(3, 4) >>> b = torch.randn(4, 5) >>> c = torch.randn(5, 6) >>> d = torch.randn(6, 7) >>> # will raise a deprecation warning >>> torch.chain_matmul(a, b, c, d) tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614], [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163], [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]])