展平
- 类torch.nn.Flatten(start_dim=1, end_dim=-1)[源代码]
-
将连续的维度范围展平为一个张量。
与
Sequential
一起使用时,请参见torch.flatten()
的详情。- 形状:
-
-
输入: $(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)$,其中$S_{i}$表示维度i的大小,$*$表示任意数量的维度(包括零个)。
-
输出为: $(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)$.
-
- 示例:
-
>>> input = torch.randn(32, 1, 5, 5) >>> # With default parameters >>> m = nn.Flatten() >>> output = m(input) >>> output.size() torch.Size([32, 25]) >>> # With non-default parameters >>> m = nn.Flatten(0, 2) >>> output = m(input) >>> output.size() torch.Size([160, 5])