torch.lu

torch.lu(*args, **<kwargs)

计算矩阵或矩阵批次的 LU 分解 A。返回一个包含矩阵 A 的 LU 分解和选主元的元组。如果 pivot 被设置为 True,则会进行选主元操作。

警告

torch.lu() 已被建议使用torch.linalg.lu_factor()torch.linalg.lu_factor_ex() 替代。在未来的 PyTorch 版本中,torch.lu() 将被移除。LU, pivots, info = torch.lu(A, compute_pivots) 应该替换为

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, pivot_type, get_infos=True) 应该被替换为

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 对于批次中的每个矩阵,返回的置换矩阵由一个大小为 min(A.shape[-2], A.shape[-1]) 的一索引向量表示。其中 pivots[i] == j 表示在算法的第 i 步中,第 i 行与第 j-1 行进行了置换。

  • pivot 设置为 False 时,在 CPU 上进行 LU 分解是不可行的,这样做会抛出错误。但在 CUDA 上,则可以执行 pivot 设为 False 的 LU 分解。

  • 如果 get_infosTrue,该函数不会检查分解是否成功,因为分解的状态会在返回元组的第三个元素中显示。

  • 对于在CUDA设备上大小为32或以下的方阵批次,由于MAGMA库中的一个bug(参见magma问题13),对于奇异矩阵会重复执行LU分解。

  • L, UP 可以通过torch.lu_unpack() 获取。

警告

此函数的梯度仅在矩阵A为满秩时才存在。这是因为LU分解只在满秩矩阵上可微。此外,当A接近不满秩时,由于依赖于$L^{-1}$$U^{-1}$的计算,梯度会变得数值不稳定。

参数
  • A (Tensor) – 大小为 $(*, m, n)$ 的需要分解的张量

  • pivot (bool, optional) – 控制是否进行.pivot操作。默认值: True

  • get_infos (bool, optional) – 如果设置为 True,则返回一个 info IntTensor。默认值:False

  • out (元组, 可选) – 可选输出的元组。如果 get_infosTrue,则元组中的元素依次是 Tensor、IntTensor 和 IntTensor;若 get_infosFalse,则元组中的元素依次是 Tensor 和 IntTensor。默认值: None

返回值

包含张量的元组

  • 因子分解张量):大小为$(*, m, n)$的因子分解

  • pivots (IntTensor):大小为$(*, \text{min}(m, n))$的主元。pivots 存储了所有中间行交换操作。可以通过应用 swap(perm[i], perm[pivots[i] - 1]) 对于 i = 0, ..., pivots.size(-1) - 1 来重建perm 的最终排列,其中初始时perm$m$个元素的恒等排列(这基本上就是torch.lu_unpack() 所做的事情)。

  • infos (IntTensor, 可选):如果 get_infosTrue,则这是一个大小为 $(*)$ 的张量,其中非零值表示矩阵或每个小批量的分解是否成功。

返回类型

(Tensor, IntTensor, (可选的) IntTensor)

示例:

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!
本页目录