torch.split

torch.split(tensor, split_size_or_sections, dim=0)[源代码]

将张量分成若干部分。每一部分都是原张量的一个视图。

如果 split_size_or_sections 是整数类型,tensor 将被分割成大小相等的块(如果可能的话)。 如果张量在给定维度 dim 上的尺寸不能被 split_size 整除,那么最后一个块会较小。

如果 split_size_or_sections 是一个列表,那么 tensor 将在 dim 维度上根据 split_size_or_sections 的大小被分割成 len(split_size_or_sections) 个片段。

参数
  • tensor (Tensor) - 需要拆分的张量。

  • split_size_or_sections (int) 或 (list(int)) – 单个块的大小或每个块大小的列表

  • dim (int) - 拆分张量的维度。

返回类型

元组[张量, ...]

示例:

>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))
本页目录