torch.unflatten

torch.unflatten(input, dim, sizes) Tensor

将输入张量的一个维度扩展为多个维度。

参见

torch.flatten() 是该函数的逆操作,它将多个维度合并成一个维度。

参数
  • input (Tensor) – 需要输入的张量。

  • dim (int) – 要取消展平的维度,指定为 input.shape 中的一个索引。

  • sizes (Tuple[int]) – 新的未展平维度的形状。其中一项可以是 -1,在这种情况下,相应的输出维度将被自动推断出来。否则,sizes 的乘积必须等于 input.shape[dim]

返回值

具有指定维度的未展平输入视图。

示例:
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])
本页目录