torch.linalg.lstsq
- torch.linalg.lstsq(A, B, rcond=None, *, driver=None)
-
求解线性方程组的最小二乘问题。
令 $\mathbb{K}$ 为 $\mathbb{R}$ 或 $\mathbb{C}$,对于线性系统 $AX = B$(其中 $A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}$),最小二乘问题定义如下:
$\min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F$其中$\|-\|_F$ 表示 Frobenius 范数。
支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型。还支持矩阵批处理,如果输入是矩阵批处理,则输出将具有相同的批处理维度。
driver
选择将要使用的后端函数。对于 CPU 输入,有效值包括'gels'、'gelsy'、'gelsd' 和 'gelss'。要在 CPU 上选择最佳驱动程序,请考虑:-
如果矩阵
A
的条件数不是很大,或者你能够接受一些精度损失。-
对于一般矩阵:使用‘gelsy’ (带枢转的 QR 分解,默认选项)
-
如果
A
是满秩的,则使用‘gels’(QR方法)
-
-
如果
A
不是良态矩阵。-
‘gelsd’ (三对角化减少和奇异值分解功能)
-
但如果遇到内存问题:使用 ‘gelss’ (完整的 SVD)。
-
对于 CUDA 输入,唯一的有效选项是 ‘gels’,它假定矩阵
A
是满秩的。请参阅这些驱动程序的完整描述
rcond
用于确定矩阵A
在driver
是 ('gelsy', 'gelsd', 'gelss') 模式下的有效秩。在这种情况下,如果 $\sigma_i$ 是矩阵 A 的按降序排列的奇异值,则当 $\sigma_i \leq \text{rcond} \cdot \sigma_1$ 时,$\sigma_i$ 将被舍入为零。如果rcond
= None(默认值),则将rcond
设置为矩阵A
的数据类型机器精度乘以 max(m, n)。此函数以一个包含四个张量的命名元组形式返回问题的解决方案及一些额外信息:(solution, residuals, rank, singular_values)。对于输入
A
和B
,它们分别具有形状(*, m, n)和(*, m, k),具体包含如下内容:-
解决方案:最小二乘解。其形状为(*, n, k)。
-
residuals: 解的平方残差,即 $\|AX - B\|_F^2$。其形状与
A
的批次维度相同。当m > n 且A
中的每个矩阵都是满秩时,才会进行计算;否则返回一个空张量。如果A
是一组矩阵,并且其中任一矩阵不是满秩,则返回一个空张量。此行为在未来 PyTorch 版本中可能会发生变化。 -
rank: 表示
A
中矩阵秩的张量。其形状与A
的批次维度相同。当driver
为'gelsy'、'gelsd'或'gelss'之一时进行计算,否则它是一个空张量。 -
singular_values: 表示矩阵
A
的奇异值张量。其形状为(*, min(m, n))。当driver
是‘gelsd’或‘gelss’时进行计算,否则它是一个空张量。
注意
此函数以更快且数值更稳定的方式计算 X =
A
.pinverse() @B
,而不是单独执行这些计算。警告
未来版本的 PyTorch 中
rcond
的默认值可能会改变,因此建议使用固定值以避免潜在的不兼容问题。- 参数
- 关键字参数
-
driver (str, 可选) – 指定要使用的 LAPACK/MAGMA 方法的名称。如果未指定,默认情况下,对于 CPU 输入使用 ‘gelsy’,对于 CUDA 输入使用 ‘gels’。
- 返回值
-
一个名为(solution, residuals, rank, singular_values)的命名元组。
示例:
>>> A = torch.randn(1,3,3) >>> A tensor([[[-1.0838, 0.0225, 0.2275], [ 0.2438, 0.3844, 0.5499], [ 0.1175, -0.9102, 2.0870]]]) >>> B = torch.randn(2,3,3) >>> B tensor([[[-0.6772, 0.7758, 0.5109], [-1.4382, 1.3769, 1.1818], [-0.3450, 0.0806, 0.3967]], [[-1.3994, -0.1521, -0.1473], [ 1.9194, 1.0458, 0.6705], [-1.1802, -0.9796, 1.4086]]]) >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) >>> torch.dist(X, torch.linalg.pinv(A) @ B) tensor(1.5152e-06) >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values >>> torch.dist(S, torch.linalg.svdvals(A)) tensor(2.3842e-07) >>> A[:, 0].zero_() # Decrease the rank of A >>> rank = torch.linalg.lstsq(A, B).rank >>> rank tensor([2])
-