展平

class torch.nn.Unflatten(dim, unflattened_size) [源代码]

展平张量的维度,将其扩展为所需的形状。适用于与Sequential一起使用。

  • dim 指定要展开的输入张量的维度。当使用 Tensor 时,dimint 类型;当使用 NamedTensor 时,dimstr 类型。

  • unflattened_size 是张量未展平维度的新形状。它可以是一个包含整数的元组、列表或 torch.Size;对于NamedTensor输入,它也可以是一个由(name, size)元组组成的NamedShape

形状:
  • 输入: $(*, S_{\text{dim}}, *)$,其中$S_{\text{dim}}$表示维度dim的大小,而$*$则表示任意数量的维度,包括零个。

  • 输出为:$(*, U_1, ..., U_n, *)$,其中 $U$ 等于 unflattened_size,并且 $\prod_{i=1}^n U_i = S_{\text{dim}}$

参数
  • dim (Union[int, str]) – 需要展开的维度

  • unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 未展平维度的新形状

示例

>>> input = torch.randn(2, 50)
>>> # With tuple of ints
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With torch.Size
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
本页目录