torch.lu
- torch.lu(*args, **<kwargs)
-
计算矩阵或矩阵批次的 LU 分解
A
。返回一个包含矩阵A
的 LU 分解和选主元的元组。如果pivot
被设置为True
,则会进行选主元操作。警告
torch.lu()
已被建议使用torch.linalg.lu_factor()
和torch.linalg.lu_factor_ex()
替代。在未来的 PyTorch 版本中,torch.lu()
将被移除。LU, pivots, info = torch.lu(A, compute_pivots)
应该替换为LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
LU, pivots, info = torch.lu(A, pivot_type, get_infos=True)
应该被替换为LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
注意
-
对于批次中的每个矩阵,返回的置换矩阵由一个大小为
min(A.shape[-2], A.shape[-1])
的一索引向量表示。其中pivots[i] == j
表示在算法的第i
步中,第i
行与第j-1
行进行了置换。 -
当
pivot
设置为False
时,在 CPU 上进行 LU 分解是不可行的,这样做会抛出错误。但在 CUDA 上,则可以执行pivot
设为False
的 LU 分解。 -
如果
get_infos
为True
,该函数不会检查分解是否成功,因为分解的状态会在返回元组的第三个元素中显示。 -
对于在CUDA设备上大小为32或以下的方阵批次,由于MAGMA库中的一个bug(参见magma问题13),对于奇异矩阵会重复执行LU分解。
-
L
,U
和P
可以通过torch.lu_unpack()
获取。
警告
此函数的梯度仅在矩阵
A
为满秩时才存在。这是因为LU分解只在满秩矩阵上可微。此外,当A
接近不满秩时,由于依赖于$L^{-1}$和$U^{-1}$的计算,梯度会变得数值不稳定。- 参数
-
-
A (Tensor) – 大小为 $(*, m, n)$ 的需要分解的张量
-
pivot (bool, optional) – 控制是否进行.pivot操作。默认值:
True
-
get_infos (bool, optional) – 如果设置为
True
,则返回一个 info IntTensor。默认值:False
-
out (元组, 可选) – 可选输出的元组。如果
get_infos
为True
,则元组中的元素依次是 Tensor、IntTensor 和 IntTensor;若get_infos
为False
,则元组中的元素依次是 Tensor 和 IntTensor。默认值:None
-
- 返回值
-
包含张量的元组
-
因子分解(张量):大小为$(*, m, n)$的因子分解
-
pivots (IntTensor):大小为$(*, \text{min}(m, n))$的主元。
pivots
存储了所有中间行交换操作。可以通过应用swap(perm[i], perm[pivots[i] - 1])
对于i = 0, ..., pivots.size(-1) - 1
来重建perm
的最终排列,其中初始时perm
是$m$个元素的恒等排列(这基本上就是torch.lu_unpack()
所做的事情)。 -
infos (IntTensor, 可选):如果
get_infos
为True
,则这是一个大小为 $(*)$ 的张量,其中非零值表示矩阵或每个小批量的分解是否成功。
-
- 返回类型
-
(Tensor, IntTensor, (可选的) IntTensor)
示例:
>>> A = torch.randn(2, 3, 3) >>> A_LU, pivots = torch.lu(A) >>> A_LU tensor([[[ 1.3506, 2.5558, -0.0816], [ 0.1684, 1.1551, 0.1940], [ 0.1193, 0.6189, -0.5497]], [[ 0.4526, 1.2526, -0.3285], [-0.7988, 0.7175, -0.9701], [ 0.2634, -0.9255, -0.3459]]]) >>> pivots tensor([[ 3, 3, 3], [ 3, 3, 3]], dtype=torch.int32) >>> A_LU, pivots, info = torch.lu(A, get_infos=True) >>> if info.nonzero().size(0) == 0: ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples!
-