torch.nn.functional.embedding
- torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)[源代码]
-
生成一个简单的查找表,在固定字典和大小的范围内查找嵌入。
此模块常用于通过索引检索词嵌入。其输入包括一个索引列表和一个嵌入矩阵,输出则是对应的词嵌入。
更多详情请参见
torch.nn.Embedding
。注意
请注意,此函数相对于
weight
中由padding_idx
指定的行进行解析求导的结果可能与数值求导的结果不同。注意
请注意,:class:`torch.nn.Embedding` 与该函数的不同之处在于:它在构建时将由
padding_idx
指定的weight
的行初始化为全零。- 参数
-
-
输入 (LongTensor) – 包含嵌入矩阵索引的张量
-
weight (Tensor) – 嵌入矩阵,其行数为最大可能索引加1,列数为嵌入维度大小。
-
padding_idx (int, 可选) – 如果指定了
padding_idx
,则该索引位置的条目不会影响梯度计算;因此,在训练过程中,padding_idx
处的嵌入向量保持不变,即它作为一个固定的“填充”值。 -
max_norm (float, optional) – 如果提供,每个范数大于
max_norm
的嵌入向量将被重新规范化为范数max_norm
。注意:这将会就地修改weight
。 -
norm_type (float, optional) – 用于计算
max_norm
选项的p范数中的 p值。默认为2
。 -
scale_grad_by_freq (bool, 可选) – 如果设置,将通过最小批量中单词频率的逆来缩放梯度。默认值为
False
。 -
sparse (bool, 可选) – 如果为
True
,weight
的梯度将是一个稀疏张量。有关稀疏梯度的更多详细信息,请参见torch.nn.Embedding
中的注释。
-
- 返回类型
- 形状:
-
-
输入: 任意形状的 LongTensor,包含要提取的索引
-
权重:一个形状为(V, embedding_dim)的浮点类型嵌入矩阵,其中 V 表示最大索引加 1,而 embedding_dim 是嵌入维度。
-
输出: (*, embedding_dim),其中 * 表示输入形状
-
示例:
>>> # a batch of 2 samples of 4 indices each >>> input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) >>> # an embedding matrix containing 10 tensors of size 3 >>> embedding_matrix = torch.rand(10, 3) >>> F.embedding(input, embedding_matrix) tensor([[[ 0.8490, 0.9625, 0.6753], [ 0.9666, 0.7761, 0.6108], [ 0.6246, 0.9751, 0.3618], [ 0.4161, 0.2419, 0.7383]], [[ 0.6246, 0.9751, 0.3618], [ 0.0237, 0.7794, 0.0528], [ 0.9666, 0.7761, 0.6108], [ 0.3385, 0.8612, 0.1867]]]) >>> # example with padding_idx >>> weights = torch.rand(10, 3) >>> weights[0, :].zero_() >>> embedding_matrix = weights >>> input = torch.tensor([[0, 2, 0, 5]]) >>> F.embedding(input, embedding_matrix, padding_idx=0) tensor([[[ 0.0000, 0.0000, 0.0000], [ 0.5609, 0.5384, 0.8720], [ 0.0000, 0.0000, 0.0000], [ 0.6262, 0.2438, 0.7471]]])