torch.matmul

torch.matmul(input, other, *, out=None) Tensor

两个张量的矩阵相乘结果。

行为根据张量的维度如下:

  • 如果两个张量都是一维的,则返回它们的点积(标量)。

  • 如果两个参数都是二维的,则返回矩阵乘积。

  • 如果第一个参数是一维的,而第二个参数是二维的,则在进行矩阵乘法时,会为其维度前加上一个1。完成矩阵乘法后,这个额外添加的维度会被移除。

  • 如果第一个参数是二维的,第二个参数是一维的,则返回矩阵与向量的乘积。

  • 如果两个参数都是至少一维的,并且其中至少有一个是N维(N > 2),则返回批量矩阵乘法。如果第一个参数是一维的,则在进行批量矩阵乘法时在其维度前添加一个1,操作完成后移除该1。同样地,如果第二个参数是一维的,在其维度后添加一个1以进行批量矩阵乘法,操作完成后也移除该1。非矩阵(即批量)维度会被广播(因此必须是可广播的)。例如,如果input 是一个$(j \times 1 \times n \times n)$ 张量而 other 是一个$(k \times n \times n)$ 张量,那么out 将是一个$(j \times k \times n \times n)$ 张量。

    请注意,在确定输入是否可以广播时,广播逻辑仅考虑批次维度,而不考虑矩阵维度。例如,如果input 是一个 $(j \times 1 \times n \times m)$ 张量而other 是一个 $(k \times m \times p)$ 张量,尽管最后两个维度(即矩阵维度)不同,这些输入仍然可以进行广播。out 将是一个 $(j \times k \times n \times p)$ 张量。

此操作支持具有稀疏布局的参数。特别是矩阵-矩阵运算(两个参数均为二维)支持与torch.mm()相同限制条件的稀疏参数。

警告

稀疏支持是一项 beta 功能,某些布局、数据类型和设备的组合可能不被支持,或不具备自动微分功能。如果你发现了缺失的功能,请提交一个特性请求。

该操作支持TensorFloat32

在某些ROCm设备上,当使用float16输入时,此模块会采用不同的精度进行反向传播。

注意

此函数的一维点积版本不支持out参数。

参数
  • 输入 (Tensor) – 需要相乘的第一个张量

  • other (Tensor) – 需要与之相乘的第二个张量

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
本页目录