torch.bmm
- torch.bmm(input, mat2, *, out=None) → Tensor
-
对存储在
input
和mat2
中的矩阵进行批处理矩阵乘法运算。input
和mat2
都必须是三维张量,并且每个张量中包含的矩阵数量要相同。如果
input
是一个 $(b \times n \times m)$ 张量,mat2
是一个 $(b \times m \times p)$ 张量,那么out
将是一个 $(b \times n \times p)$ 张量。$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$该操作支持TensorFloat32。
在某些ROCm设备上,当使用float16输入时,此模块会采用不同的精度进行反向传播。
注意
此函数不执行广播。有关矩阵乘法的广播,请参阅
torch.matmul()
。示例:
>>> input = torch.randn(10, 3, 4) >>> mat2 = torch.randn(10, 4, 5) >>> res = torch.bmm(input, mat2) >>> res.size() torch.Size([10, 3, 5])