展平
- class torch.nn.Unflatten(dim, unflattened_size) [源代码]
-
展平张量的维度,将其扩展为所需的形状。适用于与
Sequential
一起使用。-
dim
指定要展开的输入张量的维度。当使用 Tensor 时,dim
是 int 类型;当使用 NamedTensor 时,dim
是 str 类型。 -
unflattened_size
是张量未展平维度的新形状。它可以是一个包含整数的元组、列表或 torch.Size;对于NamedTensor输入,它也可以是一个由(name, size)元组组成的NamedShape。
- 形状:
-
-
输入: $(*, S_{\text{dim}}, *)$,其中$S_{\text{dim}}$表示维度
dim
的大小,而$*$则表示任意数量的维度,包括零个。 -
输出为:$(*, U_1, ..., U_n, *)$,其中 $U$ 等于
unflattened_size
,并且 $\prod_{i=1}^n U_i = S_{\text{dim}}$。
-
- 参数
-
-
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 未展平维度的新形状
示例
>>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])
-