展平

torch.nn.Flatten(start_dim=1, end_dim=-1)[源代码]

将连续的维度范围展平为一个张量。

Sequential一起使用时,请参见torch.flatten()的详情。

形状:
  • 输入: $(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)$,其中$S_{i}$表示维度i的大小,$*$表示任意数量的维度(包括零个)。

  • 输出为: $(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)$.

参数
  • start_dim (int) – 开始展平的起始维度(默认为 1)。

  • end_dim (int) – 需要展平的最后一个维度(默认为 -1)。

示例:
>>> input = torch.randn(32, 1, 5, 5)
>>> # With default parameters
>>> m = nn.Flatten()
>>> output = m(input)
>>> output.size()
torch.Size([32, 25])
>>> # With non-default parameters
>>> m = nn.Flatten(0, 2)
>>> output = m(input)
>>> output.size()
torch.Size([160, 5])
本页目录