torch.vsplit

torch.vsplit(input, indices_or_sections) Tensors 列表

input(一个具有两个或更多维度的张量)根据indices_or_sections垂直分割成多个张量。每个分割结果都是input的一个视图。

这相当于调用 torch.tensor_split(input, indices_or_sections, dim=0)(其中分割维度为 0),除了当 indices_or_sections 是一个整数时,必须能整除该维度,否则会抛出运行时错误。

此函数基于 NumPy 的 numpy.vsplit()

参数
示例:
>>> t = torch.arange(16.0).reshape(4,4)
>>> t
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])
>>> torch.vsplit(t, 2)
(tensor([[0., 1., 2., 3.],
         [4., 5., 6., 7.]]),
 tensor([[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]))
>>> torch.vsplit(t, [3, 6])
(tensor([[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]]),
 tensor([[12., 13., 14., 15.]]),
 tensor([], size=(0, 4)))
本页目录