Causal Variant

torch.nn.attention.bias.CausalVariant(value)[源代码]

用于注意力机制的因果变异枚举。

定义了两种类型的因果偏见:

UPPER_LEFT: 表示标准因果注意力的上左三角偏置。构建该偏置的等效 PyTorch 代码如下:

torch.tril(torch.ones(size, dtype=torch.bool))

例如,当 shape=(3,4) 时,生成的偏置张量将会是:

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]

LOWER_RIGHT: 表示右下三角偏置,其中的值与矩阵的右下角对齐。

构建此偏置的等效PyTorch代码为:

diagonal_offset = size[1] - size[0]
torch.tril(
    torch.ones(size, dtype=torch.bool),
    diagonal=diagonal_offset,
)

例如,当 shape=(3,4) 时,生成的偏置张量将会是:

[[1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

需要注意的是,当查询和键值张量的序列长度相等时,这些变体是等价的,因为此时的三角矩阵是一个方阵。

警告

这是一个原型枚举,可能还会更改。

本页目录