展平
- class torch.nn.Unflatten(dim, unflattened_size) [源代码]
- 
    展平张量的维度,将其扩展为所需的形状。适用于与 Sequential一起使用。- 
      dim指定要展开的输入张量的维度。当使用 Tensor 时,dim是 int 类型;当使用 NamedTensor 时,dim是 str 类型。
- 
      unflattened_size是张量未展平维度的新形状。它可以是一个包含整数的元组、列表或 torch.Size;对于NamedTensor输入,它也可以是一个由(name, size)元组组成的NamedShape。
 - 形状:
- 
      - 
        输入: $(*, S_{\text{dim}}, *)$,其中$S_{\text{dim}}$表示维度 dim的大小,而$*$则表示任意数量的维度,包括零个。
- 
        输出为:$(*, U_1, ..., U_n, *)$,其中 $U$ 等于 unflattened_size,并且 $\prod_{i=1}^n U_i = S_{\text{dim}}$。
 
- 
        
 - 参数
- 
      
- 
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 未展平维度的新形状 
 
 示例 >>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5]) 
-