torch.cat

torch.cat(tensors, dim=0, *, out=None)Tensor

沿指定维度连接给定的seq张量序列。所有张量的形状必须相同(连接维度除外),或者是一个大小为(0,)的一维空张量。

torch.cat() 可以看作是 torch.split()torch.chunk() 的逆操作。

torch.cat() 最好通过例子来理解。

参见

torch.stack() 沿着新的维度将给定的序列进行连接。

参数
  • tensors (序列 of Tensor) – 任何相同类型的张量组成的 Python 序列。提供的非空张量必须具有相同的形状,除了在拼接的维度上。

  • dim (int, optional) – 要拼接的张量的维度

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
本页目录