torch.nn.utils.rnn.pack_padded_sequence
- torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源代码]
-
将包含不同长度填充序列的张量进行打包。
input
的尺寸可以是T x B x *
(如果batch_first
是False
),或者B x T x *
(如果batch_first
是True
)。其中,T
表示最长序列的长度,B
表示批量大小,而*
可以是任意数量的维度(包括 0)。对于未排序的序列,请将enforce_sorted设置为
False
。如果enforce_sorted
为True
,则序列应按长度降序排列,即input[:,0]
是最长的序列,而input[:,B-1]
是最短的序列。enforce_sorted = True仅在导出ONNX时需要。注意
此函数接受任何至少二维的输入。你可以将其应用于打包标签,并使用 RNN 的输出直接计算损失。可以通过访问
PackedSequence
对象的.data
属性来从张量中检索数据。- 参数
- 返回值
-
一个
PackedSequence
对象 - 返回类型