torch.tril

torch.tril(input, diagonal=0, *, out=None) Tensor

返回矩阵(二维张量)或矩阵批次 input 的下三角部分,并将结果张量 out 中的其他元素设为 0。

矩阵的下三角部分是指对角线及其下方的所有元素。

参数diagonal 控制要选择的对角线。如果 diagonal = 0,则保留主对角线及其以下的所有元素。正值表示包括主对角线上方同样数量的对角线,负值则表示排除主对角线下方同样数量的对角线。主对角线上的索引集合为 $\lbrace (i, i) \rbrace$,其中$i \in [0, \min\{d_{1}, d_{2}\} - 1]$$d_{1}, d_{2}$ 是矩阵的维度。

参数
  • input (Tensor) – 需要输入的张量。

  • diagonal (int, 可选) – 需要考虑的对角线

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> a = torch.randn(3, 3)
>>> a
tensor([[-1.0813, -0.8619,  0.7105],
        [ 0.0935,  0.1380,  2.2112],
        [-0.3409, -0.9828,  0.0289]])
>>> torch.tril(a)
tensor([[-1.0813,  0.0000,  0.0000],
        [ 0.0935,  0.1380,  0.0000],
        [-0.3409, -0.9828,  0.0289]])

>>> b = torch.randn(4, 6)
>>> b
tensor([[ 1.2219,  0.5653, -0.2521, -0.2345,  1.2544,  0.3461],
        [ 0.4785, -0.4477,  0.6049,  0.6368,  0.8775,  0.7145],
        [ 1.1502,  3.2716, -1.1243, -0.5413,  0.3615,  0.6864],
        [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024,  0.0978]])
>>> torch.tril(b, diagonal=1)
tensor([[ 1.2219,  0.5653,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.4785, -0.4477,  0.6049,  0.0000,  0.0000,  0.0000],
        [ 1.1502,  3.2716, -1.1243, -0.5413,  0.0000,  0.0000],
        [-0.0614, -0.7344, -1.3164, -0.7648, -1.4024,  0.0000]])
>>> torch.tril(b, diagonal=-1)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.4785,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.1502,  3.2716,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0614, -0.7344, -1.3164,  0.0000,  0.0000,  0.0000]])
本页目录