torch.linalg.tensorinv
- torch.linalg.tensorinv(A, ind=2, *, out=None)) → Tensor
-
计算
torch.tensordot()
的乘法逆元素。如果 m 是
ind
维度的乘积,而 n 是其余维度的乘积,则此函数期望 m 和 n 相等。如果是这种情况,它会计算一个张量 X,使得 tensordot(A
, X,ind
) 在维度 m 上是单位矩阵。张量 X 将具有与A
相同的形状,但前ind
个维度会被移到末尾。X.shape == A.shape[ind:] + A.shape[:ind]
支持浮点型、双精度型、复数浮点型和复数双精度型数据类型的输入。
注意
当
A
是一个2维张量且ind
= 1时,此函数计算A
的乘法逆矩阵(参见torch.linalg.inv()
)。注意
如果可能的话,建议使用
torch.linalg.tensorsolve()
来实现用张量的逆从左侧乘以另一个张量,具体如下:linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # When B is a tensor with shape A.shape[:B.ndim]
建议优先使用
tensorsolve()
,因为它的速度更快且数值稳定性更好。参见
torch.linalg.tensorsolve()
计算torch.tensordot(torch.tensorinv(A), B)
。- 参数
-
-
A (Tensor) – 需要求逆的张量。其形状必须满足
prod(A.shape[:ind]) == prod(A.shape[ind:])
。 -
ind (int) – 计算
torch.tensordot()
逆的索引。默认值: 2。
-
- 关键字参数
-
out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。
- 异常
-
RuntimeError – 如果重塑后的
A
不可逆,或者前ind
个维度的乘积不等于其余维度的乘积。
示例:
>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) >>> Ainv = torch.linalg.tensorinv(A, ind=2) >>> Ainv.shape torch.Size([8, 3, 4, 6]) >>> B = torch.randn(4, 6) >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) True >>> A = torch.randn(4, 4) >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) >>> Ainv = torch.linalg.inv(A) >>> torch.allclose(Atensorinv, Ainv) True