torch.bmm

torch.bmm(input, mat2, *, out=None) → Tensor

对存储在inputmat2中的矩阵进行批处理矩阵乘法运算。

inputmat2 都必须是三维张量,并且每个张量中包含的矩阵数量要相同。

如果 input 是一个 $(b \times n \times m)$ 张量,mat2 是一个 $(b \times m \times p)$ 张量,那么 out 将是一个 $(b \times n \times p)$ 张量。

$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$

该操作支持TensorFloat32

在某些ROCm设备上,当使用float16输入时,此模块会采用不同的精度进行反向传播。

注意

此函数不执行广播。有关矩阵乘法的广播,请参阅 torch.matmul()

参数
  • 输入 (Tensor) – 需要相乘的第一批矩阵

  • mat2 (Tensor) – 第二批要相乘的矩阵

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
本页目录