torch.triangular_solve
- torch.triangular_solve(b, A, upper=True, transpose=False, unitriangular=False, *, out=None)
-
求解一个方程组,其中方阵 $A$ 是上三角或下三角可逆矩阵,并且有多个右端向量 $b$。
用符号表示,它解决$AX = b$ 问题,并假设矩阵$A$ 是方阵且为上三角形(如果
upper
= False 则为下三角形),并且对角线上没有零元素。torch.triangular_solve(b, A) 可以接受二维输入 b, A,或者是一批二维矩阵的输入。如果是批量输入,则返回一批输出 X
如果矩阵A的对角线包含零或接近零的元素,并且
unitriangular
设置为 False(默认值),或者输入矩阵条件数很差,结果中可能包含NaN。支持浮点数、双精度浮点数、复数浮点数和复数双精度浮点数的输入。
警告
torch.triangular_solve()
已被torch.linalg.solve_triangular()
替代,并将在未来的 PyTorch 版本中移除。新的函数参数顺序相反,且不返回输入的副本。X = torch.linalg.solve_triangular(A, B)
应该替换为X = torch.linalg.solve_triangular(A, B)
- 参数
-
-
b (Tensor) – 多个右侧项,大小为 $(*, m, k)$,其中$*$表示零或多个批次维度。
-
A (Tensor) – 输入的三角系数矩阵,大小为$(*, m, m)$,其中$*$表示零个或多个批次维度。
-
upper (bool, optional) – 表示矩阵$A$ 是否为上三角或下三角。默认值:
True
。 -
transpose (bool, 可选) – 如果此标志为
True
,则解方程op(A)X = b 其中 op(A) = A^T; 如果是False
,则 op(A) = A。默认值:False
。 -
unitriangular (bool, 可选) – 指定矩阵$A$ 是否为单位三角矩阵。如果设置为 True,则认为$A$ 的对角线元素为 1,且不会从$A$ 中引用这些元素。默认值:
False
。
-
- 关键字参数
-
out ((Tensor, Tensor), 可选) – 用于存储输出结果的两个张量元组。如果未指定,则忽略此参数。默认值: None。
- 返回值
-
一个命名元组 (solution, cloned_coefficient),其中 cloned_coefficient 是矩阵 $A$ 的副本,而 solution 是方程组 $AX = b$(或其变体)的解 $X$。
示例:
>>> A = torch.randn(2, 2).triu() >>> A tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]) >>> b = torch.randn(2, 3) >>> b tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) >>> torch.triangular_solve(b, A) torch.return_types.triangular_solve( solution=tensor([[ 1.7841, 2.9046, -2.5405], [ 1.9320, 0.9270, -1.2826]]), cloned_coefficient=tensor([[ 1.1527, -1.0753], [ 0.0000, 0.7986]]))