torch.linalg.tensorsolve

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) Tensor

求解系统 torch.tensordot(A, X) = B 中的 X

如果 mA 的前 B.ndim 个维度的乘积,而 n 是其余维度的乘积,那么此函数期望 mn 相等。

返回的张量 x 满足公式 tensordot(A, x, dims=x.ndim) == B。其中,x 的形状为 A[B.ndim:]

如果指定了dims,矩阵A将被重新塑形为

A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))

支持浮点型、双精度型、复数浮点型和复数双精度型数据类型。

参见

torch.linalg.tensorinv() 计算 torch.tensordot() 的乘法逆矩阵。

参数
  • A (Tensor) – 需要求解的张量。其形状必须满足 prod(A.shape[:B.ndim]) == prod(A.shape[B.ndim:])

  • B (Tensor) – 形状为 A.shape[:B.ndim] 的张量。

  • dims (Tuple[int], optional) – 指定要移动的 A 的维度。如果为None,则不进行任何维度移动。默认值:None

关键字参数

out (Tensor, optional) – 输出张量。默认为None,若未指定则忽略。

异常

RuntimeError – 如果重塑后的 A.view(m, m) 在上述条件下不可逆,或者前 ind 维度的乘积不等于其余维度的乘积。

示例:

>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True
本页目录