torch.nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)[源代码]

对一批长度不同的序列进行填充。

它是pack_padded_sequence()的逆操作。

返回的张量的数据形状将为 T x B x *(如果 batch_firstFalse)或 B x T x *(如果 batch_firstTrue),其中 T 表示最长序列的长度,B 表示批量大小。

示例

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

注意

total_length 有助于在使用 DataParallel 包装的 Module 中实现 pack sequence -> recurrent network -> unpack sequence 模式。详情请参阅此 FAQ 部分

参数
  • sequence (PackedSequence) – 待填充的批次

  • batch_first (bool, optional) – 如果为True,输出格式将为B x T x *,否则为T x B x *

  • padding_value (float, 可选) – 用于填充元素的值。

  • total_length (int, optional) – 如果不为 None,输出将被填充至长度为 total_length。如果 total_length 小于 sequence 中的最大序列长度,则此方法会抛出 ValueError

返回值

返回一个元组,其中包含填充后的序列张量和一个存储每个序列长度的张量。批次中的元素会按照最初传递给 pack_padded_sequencepack_sequence 时的顺序重新排序。

返回类型

Tuple[Tensor, Tensor]

本页目录