torch.squeeze

torch.squeeze(input, dim=None)Tensor

返回一个张量,移除输入张量中所有大小为1的维度。

例如,如果 input 的形状为:$(A \times 1 \times B \times C \times 1 \times D)$,那么执行 input.squeeze() 后的形状将会是:$(A \times B \times C \times D)$

当给定dim时,挤压操作仅在指定的维度上进行。如果输入的形状为:$(A \times 1 \times B)$,则squeeze(input, 0)不会改变张量,但squeeze(input, 1)会将张量挤压为形状$(A \times B)$

注意

返回的张量与输入张量共用同一存储空间,所以修改其中一个张量的数据也会改变另一个张量的数据。

警告

如果张量的批次维度大小为1,squeeze(input)会移除该批次维度,可能导致意外错误。建议只指定你想压缩的具体维度。

参数
  • input (Tensor) – 需要输入的张量。

  • dim (int元组 of ints, 可选) –

    如果提供了输入,它将被压缩

    仅在指定的维度内。

    从版本 2.0 开始:dim 现在支持接受维度元组。

示例:

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
>>> y = torch.squeeze(x, (1, 2, 3))
torch.Size([2, 2, 2])
本页目录