torch.tensor_split
- torch.tensor_split(input, indices_or_sections, dim=0) → Tensors 列表
 - 
    
沿维度
dim根据indices_or_sections指定的索引或部分数量,将张量分割为多个子张量。所有子张量都是input的视图。此函数基于 NumPy 的numpy.array_split()。- 参数
 - 
      
- 
        
input (Tensor) – 需要进行拆分的张量
 - 
        
indices_or_sections (Tensor, int 或 list 或 tuple of ints) –
如果
indices_or_sections是一个整数n或一个值为n的零维长张量,那么input将在维度dim上被分割成n个部分。如果input在维度dim上可以被n整除,则每个部分的大小相等,为input.size(dim) / n。如果不能整除,则前int(input.size(dim) % n)个部分的大小为int(input.size(dim) / n) + 1,其余部分的大小为int(input.size(dim) / n)。如果
indices_or_sections是一个整数列表或元组,或者是一个一维的长张量,则会根据dim维度在这些索引处将input切分。例如,indices_or_sections=[2, 3]和dim=0会生成张量input[:2],input[2:3],和input[3:]。如果
indices_or_sections是一个张量,它必须是位于 CPU 上的零维或一维长整型张量。 - 
        
dim (int, 可选) – 拆分张量的维度。默认值:
0 
 - 
        
 
示例:
>>> x = torch.arange(8) >>> torch.tensor_split(x, 3) (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7])) >>> x = torch.arange(7) >>> torch.tensor_split(x, 3) (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) >>> torch.tensor_split(x, (1, 6)) (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6])) >>> x = torch.arange(14).reshape(2, 7) >>> x tensor([[ 0, 1, 2, 3, 4, 5, 6], [ 7, 8, 9, 10, 11, 12, 13]]) >>> torch.tensor_split(x, 3, dim=1) (tensor([[0, 1, 2], [7, 8, 9]]), tensor([[ 3, 4], [10, 11]]), tensor([[ 5, 6], [12, 13]])) >>> torch.tensor_split(x, (1, 6), dim=1) (tensor([[0], [7]]), tensor([[ 1, 2, 3, 4, 5], [ 8, 9, 10, 11, 12]]), tensor([[ 6], [13]]))