torch.split
- torch.split(tensor, split_size_or_sections, dim=0)[源代码]
-
将张量分成若干部分。每一部分都是原张量的一个视图。
如果
split_size_or_sections
是整数类型,tensor
将被分割成大小相等的块(如果可能的话)。 如果张量在给定维度dim
上的尺寸不能被split_size
整除,那么最后一个块会较小。如果
split_size_or_sections
是一个列表,那么tensor
将在dim
维度上根据split_size_or_sections
的大小被分割成len(split_size_or_sections)
个片段。- 参数
- 返回类型
示例:
>>> a = torch.arange(10).reshape(5, 2) >>> a tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) >>> torch.split(a, 2) (tensor([[0, 1], [2, 3]]), tensor([[4, 5], [6, 7]]), tensor([[8, 9]])) >>> torch.split(a, [1, 4]) (tensor([[0, 1]]), tensor([[2, 3], [4, 5], [6, 7], [8, 9]]))