Causal Variant
- 类torch.nn.attention.bias.CausalVariant(value)[源代码]
-
用于注意力机制的因果变异枚举。
定义了两种类型的因果偏见:
UPPER_LEFT: 表示标准因果注意力的上左三角偏置。构建该偏置的等效 PyTorch 代码如下:
torch.tril(torch.ones(size, dtype=torch.bool))
例如,当 shape=(3,4) 时,生成的偏置张量将会是:
[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
LOWER_RIGHT: 表示右下三角偏置,其中的值与矩阵的右下角对齐。
构建此偏置的等效PyTorch代码为:
diagonal_offset = size[1] - size[0] torch.tril( torch.ones(size, dtype=torch.bool), diagonal=diagonal_offset, )
例如,当 shape=(3,4) 时,生成的偏置张量将会是:
[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
需要注意的是,当查询和键值张量的序列长度相等时,这些变体是等价的,因为此时的三角矩阵是一个方阵。
警告
这是一个原型枚举,可能还会更改。