torch.linalg.householder_product

torch.linalg.householder_product(A, tau, *, out=None) Tensor

计算Householder矩阵乘积的前n个列。

\(\mathbb{K}\) 为实数集 \(\mathbb{R}\) 或复数集 \(\mathbb{C}\),并设矩阵 \(A\) 属于 \(\mathbb{K}^{m \times n}\),其列向量为 \(a_i \in \mathbb{K}^m\)(对于 \(i=1,\ldots,n\)),且满足 \(m \geq n\)。记由将 \(a_i\) 的前 \(i-1\) 个分量置零并将第 \(i\) 个分量设为 1 得到的向量为 \(b_i\)。对于一个长度不超过 \(n\) 的向量 \(\tau \in \mathbb{K}^k\),此函数计算矩阵的前 \(n\) 列。

$H_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i b_i b_i^{\text{H}}$

其中,$\mathrm{I}_m$m 维的单位矩阵;当$b$ 为复数时,$b^{\text{H}}$ 表示共轭转置;当$b$ 为实数时,则表示普通转置。输出矩阵与输入矩阵 A 的大小相同。

关于正交或酉矩阵的表示方法的详细信息,请参见正交或酉矩阵的表示法

支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型。还支持矩阵批处理,如果输入是矩阵批处理,则输出将具有相同的批处理维度。

参见

torch.geqrf() 可以与该函数一起使用,从 qr() 分解中提取 Q

torch.ormqr() 是一个相关的函数,它计算 Householder 矩阵乘积与另一个矩阵的矩阵乘法。但是,此函数不受 autograd 支持。

警告

只有当$\tau_i \neq \frac{1}{||a_i||^2}$时,梯度计算才有明确的定义。如果不满足此条件,则不会抛出错误,但生成的梯度可能包含NaN

参数
  • A (Tensor) – 形状为(*, m, n)的张量,其中*表示零个或多个批次维度。

  • tau (Tensor) – 形状为 (*, k) 的张量,其中 * 表示零个或多个批次维度。

关键字参数

out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。

异常

RuntimeError – 如果 A 不满足条件 m >= n,或者 tau 不满足条件 n >= k

示例:

>>> A = torch.randn(2, 2)
>>> h, tau = torch.geqrf(A)
>>> Q = torch.linalg.householder_product(h, tau)
>>> torch.dist(Q, torch.linalg.qr(A).Q)
tensor(0.)

>>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
>>> tau = torch.randn(3, 1, dtype=torch.complex128)
>>> Q = torch.linalg.householder_product(h, tau)
>>> Q
tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
        [-0.6853+0.7953j,  2.0790+0.5620j]],

        [[ 1.4581+1.6989j, -1.5360+0.1193j],
        [ 1.3877-0.6691j,  1.3512+1.3024j]],

        [[ 1.4766+0.5783j,  0.0361+0.6587j],
        [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)
本页目录