torch.nn.utils.rnn.pack_padded_sequence
- torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源代码]
-
将包含不同长度填充序列的张量进行打包。
input的尺寸可以是T x B x *(如果batch_first是False),或者B x T x *(如果batch_first是True)。其中,T表示最长序列的长度,B表示批量大小,而*可以是任意数量的维度(包括 0)。对于未排序的序列,请将enforce_sorted设置为
False。如果enforce_sorted为True,则序列应按长度降序排列,即input[:,0]是最长的序列,而input[:,B-1]是最短的序列。enforce_sorted = True仅在导出ONNX时需要。注意
此函数接受任何至少二维的输入。你可以将其应用于打包标签,并使用 RNN 的输出直接计算损失。可以通过访问
PackedSequence对象的.data属性来从张量中检索数据。- 参数
- 返回值
-
一个
PackedSequence对象 - 返回类型