折叠
- classtorch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)[源代码]
-
将一系列滑动局部块合并成一个大的包含张量。
考虑一个批次的
input
张量,其中包含滑动局部块(例如图像补丁),形状为$(N, C \times \prod(\text{kernel\_size}), L)$。这里,$N$是批次维度,"$C \times \prod(\text{kernel\_size})$" 表示每个块内的值数量(每个块有$\prod(\text{kernel\_size})$个空间位置,每个位置包含一个$C$-通道向量),$L$是块的总数。这个形状与Unfold
输出张量的形状相同。此操作通过将重叠值相加,将这些局部块组合成一个大张量output
,其形状为$(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)$。与Unfold
类似,参数必须满足以下条件:$L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] % - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,$其中 $d$ 表示所有空间维度。
-
output_size
描述了滑动局部块的大容器张量的空间形状。当多个输入形状映射到相同数量的滑动块时(例如,stride > 0
),它有助于解决这种歧义。
padding
、stride
和dilation
参数定义了如何获取滑动块。-
stride
控制滑动块的步长。 -
padding
控制每个维度在重塑之前每侧隐式填充的零的点数。 -
dilation
控制核点之间的间距,也称为 à trous 算法。它更难描述,但这个 链接 提供了一个很好的可视化来展示dilation
的作用。
- 参数
-
如果
output_size
、kernel_size
、dilation
、padding
或stride
是一个整数或长度为 1 的元组,则它们的值会在所有空间维度上重复。 -
当有两个输出空间维度时,此操作有时被称为
col2im
。
注意
Fold
通过将所有包含块中的值相加来计算结果大张量中每个组合的值。Unfold
则从大张量中复制值以提取局部块中的数据。因此,如果这些块有重叠,则它们不是彼此的逆操作。一般来说,折叠和展开操作之间的关系如下。考虑使用相同参数创建的
Fold
和Unfold
实例:>>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...) >>> fold = nn.Fold(output_size=..., **fold_params) >>> unfold = nn.Unfold(**fold_params)
对于任何受支持的
input
张量,下列等式成立:fold(unfold(input)) == divisor * input
其中
divisor
是一个张量,仅依赖于input
的形状和数据类型:>>> input_ones = torch.ones(input.shape, dtype=input.dtype) >>> divisor = fold(unfold(input_ones))
当
divisor
张量中不含零元素时,fold
和unfold
操作互为逆运算(不考虑常数除数)。警告
目前仅支持未批处理的(3D)和批处理的(4D)图像-like输出张量。
- 形状:
-
-
输入: $(N, C \times \prod(\text{kernel\_size}), L)$ 或者 $(C \times \prod(\text{kernel\_size}), L)$
-
输出格式如下:$(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)$ 或 $(C, \text{output\_size}[0], \text{output\_size}[1], \dots)$,如上所述。
-
示例:
>>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2)) >>> input = torch.randn(1, 3 * 2 * 2, 12) >>> output = fold(input) >>> output.size() torch.Size([1, 3, 4, 5])
-