torch.stack

torch.stack(tensors, dim=0, *, out=None)Tensor

沿着新的维度合并张量序列。

所有的张量都需要具有相同的尺寸。

参见

torch.cat() 按照现有维度连接给定的序列。

参数
  • tensors (张量序列) – 需要进行拼接的张量序列

  • dim (int, 可选) – 插入的维度。必须在 0 和拼接张量的总维度数(包括)之间。默认值:0

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 0.2303, -1.1229, -0.1863]])
>>> torch.stack((x, x)) # same as torch.stack((x, x), dim=0)
tensor([[[ 0.3367,  0.1288,  0.2345],
         [ 0.2303, -1.1229, -0.1863]],

        [[ 0.3367,  0.1288,  0.2345],
         [ 0.2303, -1.1229, -0.1863]]])
>>> torch.stack((x, x)).size()
torch.Size([2, 2, 3])
>>> torch.stack((x, x), dim=1)
tensor([[[ 0.3367,  0.1288,  0.2345],
         [ 0.3367,  0.1288,  0.2345]],

        [[ 0.2303, -1.1229, -0.1863],
         [ 0.2303, -1.1229, -0.1863]]])
>>> torch.stack((x, x), dim=2)
tensor([[[ 0.3367,  0.3367],
         [ 0.1288,  0.1288],
         [ 0.2345,  0.2345]],

        [[ 0.2303,  0.2303],
         [-1.1229, -1.1229],
         [-0.1863, -0.1863]]])
>>> torch.stack((x, x), dim=-1)
tensor([[[ 0.3367,  0.3367],
         [ 0.1288,  0.1288],
         [ 0.2345,  0.2345]],

        [[ 0.2303,  0.2303],
         [-1.1229, -1.1229],
         [-0.1863, -0.1863]]])
本页目录