torch.linalg.eig
- torch.linalg.eig(A, *, out=None)
-
如果可以,则计算方阵的特征值分解。
令 $\mathbb{K}$ 为实数集 $\mathbb{R}$ 或复数集 $\mathbb{C}$,一个 $n \times n$ 的方阵 $A \in \mathbb{K}^{n \times n}$ 的特征值分解(如果存在)定义为:
$A = V \operatorname{diag}(\Lambda) V^{-1}\mathrlap{\qquad V \in \mathbb{C}^{n \times n}, \Lambda \in \mathbb{C}^n}$当且仅当$A$是对角化矩阵时,这种分解存在。如果其所有特征值都不同,则这种情况成立。
支持浮点型、双精度型、复数浮点型和复数双精度型的数据类型作为输入。还支持矩阵的批量处理,如果
A
是一组矩阵,那么输出将具有相同的批处理维度。返回的特征值没有特定的顺序。
注意
实矩阵的特征值和特征向量可能为复数。
注意
当输入位于CUDA设备上时,此函数会将其与CPU进行同步。
警告
此函数假设
A
是对角化矩阵(例如,当所有特征值都互不相同时)。如果不能被对角化,则返回的特征值将是正确的,但$A \neq V \operatorname{diag}(\Lambda)V^{-1}$。警告
返回的特征向量会被归一化为范数1。即便如此,矩阵的特征向量并不是唯一的,并且它们对于
A
来说也不是连续变化的。由于这种唯一性的缺失,不同的硬件和软件可能会计算出不同的特征向量。这种非唯一性是由将特征向量乘以$e^{i \phi}, \phi \in \mathbb{R}$会产生另一组有效的矩阵特征向量所导致的。因此,损失函数不应依赖于特征向量的相位,因为这个量没有明确定义。在计算该函数梯度时会进行检查。当输入位于CUDA设备上时,此函数梯度的计算会将该设备与CPU同步。
警告
使用特征向量张量计算的梯度仅在
A
具有不同特征值时才是有限的。此外,如果任意两个特征值之间的距离接近于零,则梯度将变得数值不稳定,因为它的计算依赖于$\frac{1}{\min_{i \neq j} |\lambda_i - \lambda_j|}$。参见
torch.linalg.eigvals()
只计算特征值。与torch.linalg.eig()
不同,eigvals()
的梯度始终是数值稳定的。torch.linalg.eigh()
用于计算赫mitte矩阵和对称矩阵的特征值分解,速度快于其他方法。torch.linalg.svd()
用于计算任意形状矩阵的另一种类型的谱分解。torch.linalg.qr()
提供了一个更快的分解方法,适用于任何形状的矩阵。- 参数
-
A (Tensor) – 形状为(*, n, n)的张量,其中*表示零个或多个批次维度,包含可对角化的矩阵。
- 关键字参数
-
out (元组, 可选) – 由两个张量组成的输出元组。如果为None则忽略。默认值:None。
- 返回值
-
一个名为(eigenvalues, eigenvectors)的命名元组,对应于上述的$\Lambda$和$V$。
特征值 和 特征向量 即使在矩阵
A
为实数时也可能是复数。特征向量将由 特征向量 的列组成。
示例:
>>> A = torch.randn(2, 2, dtype=torch.complex128) >>> A tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) >>> L, V = torch.linalg.eig(A) >>> L tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) >>> V tensor([[ 0.9218+0.0000j, 0.1882-0.2220j], [-0.0270-0.3867j, 0.9567+0.0000j]], dtype=torch.complex128) >>> torch.dist(V @ torch.diag(L) @ torch.linalg.inv(V), A) tensor(7.7119e-16, dtype=torch.float64) >>> A = torch.randn(3, 2, 2, dtype=torch.float64) >>> L, V = torch.linalg.eig(A) >>> torch.dist(V @ torch.diag_embed(L) @ torch.linalg.inv(V), A) tensor(3.2841e-16, dtype=torch.float64)