torch.lu_solve

torch.lu_solve(b, LU_data, LU_pivots, *, out=None) Tensor

使用从lu_factor()获得的 A 的部分选主元 LU 分解,返回线性系统 $Ax = b$ 的 LU 解。

此函数支持floatdoublecfloatcdouble数据类型作为input的输入。

警告

torch.lu_solve() 已被弃用,建议使用torch.linalg.lu_solve()。在未来的 PyTorch 版本中将移除torch.lu_solve()X = torch.lu_solve(B, LU, pivots) 应该被替换为

X = linalg.lu_solve(LU, pivots, B)
参数
  • b (Tensor) – 右操作数张量,大小为$(*, m, k)$,其中$*$表示零个或多个批次维度。

  • LU_data (Tensor) – 矩阵 A 的带置换的 LU 分解,由 lu_factor() 提供,大小为$(*, m, m)$,其中$*$表示零个或多个批次维度。

  • LU_pivots (IntTensor) – 来自lu_factor() 的 LU 分解的枢轴,大小为$(*, m)$(其中$*$ 表示零个或多个批处理维度)。LU_pivots 的批处理维度必须与 LU_data 的批处理维度相等。

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3, 1)
>>> LU, pivots = torch.linalg.lu_factor(A)
>>> x = torch.lu_solve(b, LU, pivots)
>>> torch.dist(A @ x, b)
tensor(1.00000e-07 *
       2.8312)
本页目录