MaxUnpool2d

torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0)[源代码]

计算MaxPool2d的部分逆运算。

MaxPool2d 不是完全可逆的,因为非最大值会被丢弃。

MaxUnpool2dMaxPool2d 的输出(包括最大值的索引)作为输入,并计算一个部分逆,其中所有非最大值都被设置为零。

注意

当输入索引包含重复值时,此操作可能会表现出非确定性的行为。详情请参见https://github.com/pytorch/pytorch/issues/80827重复性

注意

MaxPool2d 可以将多个输入大小映射到相同的输出大小,从而导致反向过程变得模糊不清。为了解决这个问题,你可以在前向调用中提供所需的输出大小作为参数 output_size。请参见下面的示例。

参数
  • kernel_size (inttuple) – 最大池化窗口的尺寸。

  • stride (inttuple) – 最大池化窗口的步长,默认值为 kernel_size

  • padding (inttuple) – 输入数据的填充值

输入:
  • input: 输入张量,需要对其进行逆运算

  • indices: MaxPool2d给出的索引

  • output_size(可选):目标输出的大小

形状:
  • 输入: $(N, C, H_{in}, W_{in})$$(C, H_{in}, W_{in})$

  • 输出为 $(N, C, H_{out}, W_{out})$$(C, H_{out}, W_{out})$

    $H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}$
    $W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}$

    或者如调用操作符中的 output_size 所给定的

示例:

>>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
>>> unpool = nn.MaxUnpool2d(2, stride=2)
>>> input = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                            [ 5.,  6.,  7.,  8.],
                            [ 9., 10., 11., 12.],
                            [13., 14., 15., 16.]]]])
>>> output, indices = pool(input)
>>> unpool(output, indices)
tensor([[[[  0.,   0.,   0.,   0.],
          [  0.,   6.,   0.,   8.],
          [  0.,   0.,   0.,   0.],
          [  0.,  14.,   0.,  16.]]]])
>>> # Now using output_size to resolve an ambiguous size for the inverse
>>> input = torch.tensor([[[[ 1.,  2.,  3.,  4.,  5.],
                            [ 6.,  7.,  8.,  9., 10.],
                            [11., 12., 13., 14., 15.],
                            [16., 17., 18., 19., 20.]]]])
>>> output, indices = pool(input)
>>> # This call will not work without specifying output_size
>>> unpool(output, indices, output_size=input.size())
tensor([[[[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  7.,  0.,  9.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0., 17.,  0., 19.,  0.]]]])
本页目录