torch.nn.functional.affine_grid

torch.nn.functional.affine_grid(theta, size, align_corners=None)[源代码]

基于一批仿射矩阵theta,生成二维或三维的流场(采样网格)。

注意

此函数通常与grid_sample() 一起使用,用于构建空间变压器网络

参数
  • theta (Tensor) – 输入的仿射矩阵批次,形状为 ($N \times 2 \times 3$)(用于2D)或 ($N \times 3 \times 4$)(用于3D)

  • size (torch.Size) – 目标输出图像的大小。对于二维(2D)为$N \times C \times H \times W$,对于三维(3D)为$N \times C \times D \times H \times W$。示例:torch.Size((32, 3, 24, 24))

  • align_corners (bool, 可选) – 如果为True,则认为-11分别表示角像素的中心而不是图像角落。有关更完整的描述,请参见grid_sample()。由affine_grid()生成的网格应与相同的align_corners选项设置传递给grid_sample()。默认值: False

返回值

输出张量的大小为 ($N \times H \times W \times 2$)

返回类型

输出(Tensor

警告

align_corners = True 时,网格位置取决于相对于输入图像大小的像素尺寸。因此,在不同分辨率下(即缩放或缩小后),给定相同输入时,由grid_sample() 采样的位置会有所不同。在1.2.0版本之前,默认行为是 align_corners = True。从那时起,默认行为更改为 align_corners = False,以使其与interpolate() 的默认值保持一致。

警告

align_corners = True时,对1D数据进行2D仿射变换或对2D数据进行3D仿射变换(即其中一个空间维度大小为单位尺寸)是未定义的,并且这不是预期的使用场景。当align_corners = False时,这不会成为一个问题。在版本1.2.0及之前,所有沿单位维度的网格点都被任意地认为是在-1处。从版本1.3.0开始,在align_corners = True的情况下,所有沿单位维度的网格点都被认为是在0(输入图像的中心)处。

本页目录