ChannelShuffle

class torch.nn.ChannelShuffle(groups) [源代码]

将张量中的通道划分并重新排列。

此操作将形状为$(N, C, *)$的张量中的通道分成g组,每组包含$\frac{C}{g}$个通道,并进行洗牌处理。最终输出保持原始张量的形状不变。

参数

groups (int) – 将通道划分成的组的数量。

示例:

>>> channel_shuffle = nn.ChannelShuffle(2)
>>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
>>> input
tensor([[[[ 1.,  2.],
          [ 3.,  4.]],
         [[ 5.,  6.],
          [ 7.,  8.]],
         [[ 9., 10.],
          [11., 12.]],
         [[13., 14.],
          [15., 16.]]]])
>>> output = channel_shuffle(input)
>>> output
tensor([[[[ 1.,  2.],
          [ 3.,  4.]],
         [[ 9., 10.],
          [11., 12.]],
         [[ 5.,  6.],
          [ 7.,  8.]],
         [[13., 14.],
          [15., 16.]]]])
本页目录