torch.linalg.solve_triangular

torch.linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None)) Tensor

计算具有唯一解的三角形线性方程组的解决方案。

$\mathbb{K}$ 为实数集 $\mathbb{R}$ 或复数集 $\mathbb{C}$,此函数计算与三角矩阵 $A \in \mathbb{K}^{n \times n}$(对角线上没有零值,即它是可逆的)以及矩形矩阵 $B \in \mathbb{K}^{n \times k}$ 相关的线性系统的解 $X \in \mathbb{K}^{n \times k}$,其定义为:

$AX = B$

参数 upper 表示矩阵 $A$ 是上三角矩阵还是下三角矩阵。

如果 left = False,此函数返回矩阵 $X \in \mathbb{K}^{n \times k}$ 以解决该系统。

$XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.}$

如果 upper= True,则只访问矩阵 A 的上三角部分;反之,若upper= False,则只访问下三角部分。主对角线以下的元素将被视为零,并不会被访问。

如果 unitriangular = True,则假设矩阵 A 的对角线元素为 1,且不会访问这些元素。

如果A的对角线包含零或接近零的元素,并且unitriangular设为< cite>= False(默认设置),或者输入矩阵具有非常小的特征值,结果中可能会出现NaN

支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型。还支持矩阵批处理,如果输入是矩阵批处理,则输出将具有相同的批处理维度。

参见

torch.linalg.solve() 用于计算一般方阵线性方程组的唯一解。

参数
  • A (Tensor) – 形状为(*, n, n)(如果leftFalse 则为(*, k, k))的张量,其中*表示零个或多个批次维度。

  • B (Tensor) – 形状为 (*, n, k) 的右侧张量。

关键字参数
  • upper (bool) – 表示矩阵 A 是上三角还是下三角矩阵。

  • left (bool, 可选) – 是否求解方程组 $AX=B$ 或者 $XA = B$。默认值: True

  • unitriangular (bool, 可选) – 如果为True,则假设矩阵A 的对角线元素全部为 1。默认值:False

  • out (Tensor, 可选) – 输出张量。B 可以作为 out 传递,并在 B 上原地计算结果。如果为 None,则忽略。默认值: None

示例:

>>> A = torch.randn(3, 3).triu_()
>>> B = torch.randn(3, 4)
>>> X = torch.linalg.solve_triangular(A, B, upper=True)
>>> torch.allclose(A @ X, B)
True

>>> A = torch.randn(2, 3, 3).tril_()
>>> B = torch.randn(2, 3, 4)
>>> X = torch.linalg.solve_triangular(A, B, upper=False)
>>> torch.allclose(A @ X, B)
True

>>> A = torch.randn(2, 4, 4).tril_()
>>> B = torch.randn(2, 3, 4)
>>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False)
>>> torch.allclose(X @ A, B)
True
本页目录