torch.matmul
- torch.matmul(input, other, *, out=None) → Tensor
-
两个张量的矩阵相乘结果。
行为根据张量的维度如下:
-
如果两个张量都是一维的,则返回它们的点积(标量)。
-
如果两个参数都是二维的,则返回矩阵乘积。
-
如果第一个参数是一维的,而第二个参数是二维的,则在进行矩阵乘法时,会为其维度前加上一个1。完成矩阵乘法后,这个额外添加的维度会被移除。
-
如果第一个参数是二维的,第二个参数是一维的,则返回矩阵与向量的乘积。
-
如果两个参数都是至少一维的,并且其中至少有一个是N维(N > 2),则返回批量矩阵乘法。如果第一个参数是一维的,则在进行批量矩阵乘法时在其维度前添加一个1,操作完成后移除该1。同样地,如果第二个参数是一维的,在其维度后添加一个1以进行批量矩阵乘法,操作完成后也移除该1。非矩阵(即批量)维度会被广播(因此必须是可广播的)。例如,如果
input
是一个$(j \times 1 \times n \times n)$ 张量而other
是一个$(k \times n \times n)$ 张量,那么out
将是一个$(j \times k \times n \times n)$ 张量。请注意,在确定输入是否可以广播时,广播逻辑仅考虑批次维度,而不考虑矩阵维度。例如,如果
input
是一个 $(j \times 1 \times n \times m)$ 张量而other
是一个 $(k \times m \times p)$ 张量,尽管最后两个维度(即矩阵维度)不同,这些输入仍然可以进行广播。out
将是一个 $(j \times k \times n \times p)$ 张量。
此操作支持具有稀疏布局的参数。特别是矩阵-矩阵运算(两个参数均为二维)支持与
torch.mm()
相同限制条件的稀疏参数。警告
稀疏支持是一项 beta 功能,某些布局、数据类型和设备的组合可能不被支持,或不具备自动微分功能。如果你发现了缺失的功能,请提交一个特性请求。
该操作支持TensorFloat32。
在某些ROCm设备上,当使用float16输入时,此模块会采用不同的精度进行反向传播。
注意
此函数的一维点积版本不支持
out
参数。示例:
>>> # vector x vector >>> tensor1 = torch.randn(3) >>> tensor2 = torch.randn(3) >>> torch.matmul(tensor1, tensor2).size() torch.Size([]) >>> # matrix x vector >>> tensor1 = torch.randn(3, 4) >>> tensor2 = torch.randn(4) >>> torch.matmul(tensor1, tensor2).size() torch.Size([3]) >>> # batched matrix x broadcasted vector >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(4) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3]) >>> # batched matrix x batched matrix >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(10, 4, 5) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5]) >>> # batched matrix x broadcasted matrix >>> tensor1 = torch.randn(10, 3, 4) >>> tensor2 = torch.randn(4, 5) >>> torch.matmul(tensor1, tensor2).size() torch.Size([10, 3, 5])
-