torch.linalg.pinv

torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) Tensor

计算矩阵的伪逆(摩尔-_penrose_逆)。

伪逆可以从代数角度定义,但通过奇异值分解(SVD)来理解它在计算上更方便。

支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型作为输入。还支持矩阵的批量处理,如果A是一组矩阵,那么输出将具有相同的批处理维度。

如果 hermitian= True,假定矩阵A 是复数情况下的 Hermite 矩阵或实数情况下的对称矩阵,但不会在内部进行验证。相反,在计算中仅使用矩阵的下三角部分。

那些小于$\max(\text{atol}, \sigma_1 \cdot \text{rtol})$阈值的奇异值(或当hermitian为 True 时,特征值的范数)在计算中被视为零并被丢弃。其中$\sigma_1$是最大的奇异值(或特征值)。

如果未指定 rtol,并且 A 是一个维度为 (m, n) 的矩阵,则相对容差设置为 $\text{rtol} = \max(m, n) \varepsilon$。其中 $\varepsilon$A 数据类型的 epsilon 值(参见 finfo)。如果未指定 rtol,但指定了大于零的 atol,则将 rtol 设置为零。

如果 atolrtol 是一个torch.Tensor,其形状必须能够广播到由torch.linalg.svd() 返回的矩阵 A 的奇异值的形状。

注意

hermitianFalse时,此函数使用torch.linalg.svd();而当hermitianTrue时,则使用torch.linalg.eigh()。对于 CUDA 输入,此函数会将设备与 CPU 同步。

注意

如果可能的话,考虑使用torch.linalg.lstsq() 来实现左乘伪逆矩阵,具体如下:

torch.linalg.lstsq(A, B).solution == A.pinv() @ B

建议优先使用lstsq(),因为它比显式计算伪逆矩阵更快、更稳定。

注意

此函数有一个与 NumPy 兼容的变体 linalg.pinv(A, rcond, hermitian=False)。然而,位置参数 rcond 已被弃用,建议使用 rtol

警告

此函数内部使用torch.linalg.svd()(当hermitian = True时使用torch.linalg.eigh()),因此其导数存在与这些函数相同的问题。有关更多详细信息,请参阅torch.linalg.svd()torch.linalg.eigh()中的警告。

参见

torch.linalg.inv() 用于计算方阵的逆矩阵。

torch.linalg.lstsq() 使用数值稳定的算法来计算 A.pinv() @ B

参数
  • A (Tensor) – 形状为(*, m, n)的张量,其中*表示零个或多个批次维度。

  • rcond (float, Tensor, optional) – [NumPy 兼容]. rtol 的别名。默认值: None

关键字参数
  • atol (float, Tensor, 可选) – 绝对容差值。当为 None 时,默认认为是零。默认值: None

  • rtol (float, Tensor, 可选) – 相对容差值。默认情况下,当为None时,请参见上文中的说明。默认: None

  • hermitian (bool, 可选) – 表示矩阵 A 是否为复数情况下的 Hermitian 矩阵或实数情况下的对称矩阵。默认值: False

  • out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。

示例:

>>> A = torch.randn(3, 5)
>>> A
tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
        [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
        [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
>>> torch.linalg.pinv(A)
tensor([[ 0.0600, -0.1933, -0.2090],
        [-0.0903, -0.0817, -0.4752],
        [-0.7124, -0.1631, -0.2272],
        [ 0.1356,  0.3933, -0.5023],
        [-0.0308, -0.1725, -0.5216]])

>>> A = torch.randn(2, 6, 3)
>>> Apinv = torch.linalg.pinv(A)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(8.5633e-07)

>>> A = torch.randn(3, 3, dtype=torch.complex64)
>>> A = A + A.T.conj()  # creates a Hermitian matrix
>>> Apinv = torch.linalg.pinv(A, hermitian=True)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(1.0830e-06)
本页目录