torch.squeeze
- torch.squeeze(input, dim=None) → Tensor
-
返回一个张量,移除输入张量中所有大小为1的维度。
例如,如果 input 的形状为:$(A \times 1 \times B \times C \times 1 \times D)$,那么执行 input.squeeze() 后的形状将会是:$(A \times B \times C \times D)$。
当给定
dim
时,挤压操作仅在指定的维度上进行。如果输入的形状为:$(A \times 1 \times B)$,则squeeze(input, 0)
不会改变张量,但squeeze(input, 1)
会将张量挤压为形状$(A \times B)$。注意
返回的张量与输入张量共用同一存储空间,所以修改其中一个张量的数据也会改变另一个张量的数据。
警告
如果张量的批次维度大小为1,squeeze(input)会移除该批次维度,可能导致意外错误。建议只指定你想压缩的具体维度。
- 参数
示例:
>>> x = torch.zeros(2, 1, 2, 1, 2) >>> x.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x) >>> y.size() torch.Size([2, 2, 2]) >>> y = torch.squeeze(x, 0) >>> y.size() torch.Size([2, 1, 2, 1, 2]) >>> y = torch.squeeze(x, 1) >>> y.size() torch.Size([2, 2, 1, 2]) >>> y = torch.squeeze(x, (1, 2, 3)) torch.Size([2, 2, 2])