torch.linalg.tensorinv

torch.linalg.tensorinv(A, ind=2, *, out=None)) Tensor

计算torch.tensordot()的乘法逆元素。

如果 mind 维度的乘积,而 n 是其余维度的乘积,则此函数期望 mn 相等。如果是这种情况,它会计算一个张量 X,使得 tensordot(A, X, ind) 在维度 m 上是单位矩阵。张量 X 将具有与 A 相同的形状,但前 ind 个维度会被移到末尾。

X.shape == A.shape[ind:] + A.shape[:ind]

支持浮点型、双精度型、复数浮点型和复数双精度型数据类型的输入。

注意

A是一个2维张量且ind= 1时,此函数计算A的乘法逆矩阵(参见torch.linalg.inv())。

注意

如果可能的话,建议使用torch.linalg.tensorsolve() 来实现用张量的逆从左侧乘以另一个张量,具体如下:

linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B)  # When B is a tensor with shape A.shape[:B.ndim]

建议优先使用tensorsolve(),因为它的速度更快且数值稳定性更好。

参见

torch.linalg.tensorsolve() 计算 torch.tensordot(torch.tensorinv(A), B)

参数
  • A (Tensor) – 需要求逆的张量。其形状必须满足 prod(A.shape[:ind]) == prod(A.shape[ind:])

  • ind (int) – 计算 torch.tensordot() 逆的索引。默认值: 2

关键字参数

out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。

异常

RuntimeError – 如果重塑后的 A 不可逆,或者前 ind 个维度的乘积不等于其余维度的乘积。

示例:

>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3))
>>> Ainv = torch.linalg.tensorinv(A, ind=2)
>>> Ainv.shape
torch.Size([8, 3, 4, 6])
>>> B = torch.randn(4, 6)
>>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B))
True

>>> A = torch.randn(4, 4)
>>> Atensorinv = torch.linalg.tensorinv(A, ind=1)
>>> Ainv = torch.linalg.inv(A)
>>> torch.allclose(Atensorinv, Ainv)
True
本页目录