torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源代码]

将包含不同长度填充序列的张量进行打包。

input 的尺寸可以是 T x B x *(如果 batch_firstFalse),或者 B x T x *(如果 batch_firstTrue)。其中,T 表示最长序列的长度,B 表示批量大小,而 * 可以是任意数量的维度(包括 0)。

对于未排序的序列,请将enforce_sorted设置为False。如果enforce_sortedTrue,则序列应按长度降序排列,即input[:,0]是最长的序列,而input[:,B-1]是最短的序列。enforce_sorted = True仅在导出ONNX时需要。

注意

此函数接受任何至少二维的输入。你可以将其应用于打包标签,并使用 RNN 的输出直接计算损失。可以通过访问PackedSequence对象的.data属性来从张量中检索数据。

参数
  • 输入 (Tensor) – 可变长度序列的填充批量。

  • lengths (Tensorlist(int) – 每个批次元素的序列长度列表。如果以张量形式提供,该张量必须在CPU上。

  • batch_first (bool, optional) – 如果为True,输入的格式应为B x T x *,否则为T x B x *

  • enforce_sorted (bool, optional) – 如果为 True,则输入的数据序列应按长度递减顺序排列。如果为 False,则会无条件地对输入数据进行排序。默认值:True

返回值

一个 PackedSequence 对象

返回类型

PackedSequence

本页目录