torch.set_float32_matmul_precision

torch.set_float32_matmul_precision(precision)[源代码]

设置 float32 矩阵乘法的内部精度。

以较低精度执行 float32 矩阵乘法可能会显著提升性能,在某些程序中,精度的损失几乎可以忽略。

支持三种不同的设置:

  • 在“最高”精度下,浮点乘法运算使用 float32 数据类型进行内部计算,该类型包含 24 位尾数,其中 23 位是显式存储的。

  • 在“high”模式下,浮点数(float32)矩阵乘法将使用TensorFloat32数据类型(显式存储10位尾数),或者将每个float32数字视为两个bfloat16数字的和(大约16位尾数,其中14位被显式存储)。前提是相应的快速矩阵乘法算法可用。否则,float32矩阵乘法则会以“最高”精度进行计算。有关bfloat16方法的更多信息,请参见下文。

  • 对于“medium”精度,如果存在内部使用 bfloat16 数据类型的快速矩阵乘法算法(bfloat16 有 8 个尾数位,其中 7 位显式存储),则 float32 矩阵乘法在内部计算时会使用该数据类型。否则,float32 矩阵乘法则按“high”精度进行计算。

当使用“高”精度时,float32的乘法可能会采用基于bfloat16的算法,这种算法比简单地截断有效位数更复杂(例如TensorFloat32为10位,bfloat16显式存储为7位)。有关此算法的完整描述,请参阅[Henry2019]。简要说明一下,在第一步中,我们发现可以将单个float32数字表示为三个bfloat16数字之和(因为float32有23位有效位,而bfloat16显式存储7位,并且两者具有相同数量的指数位)。这意味着两个float32数字的乘积可以通过九个bfloat16数字的乘积之和精确给出。通过舍弃一些这些乘积来换取速度,“高”精度算法特别只保留三个最重要的乘积,这排除了所有涉及任一输入最后8位有效位的所有乘积。这意味着我们可以将我们的输入表示为两个bfloat16数字的总和而不是三个。由于bfloat16融合乘加(FMA)指令通常比float32快10倍以上,因此使用bfloat16精度进行三次乘法和两次加法运算比使用float32精度进行一次乘法运算更快。

Henry2019

点击这里查看论文

注意

这不会改变浮点数(float32)矩阵乘法的输出数据类型,而是控制矩阵乘法内部计算的过程。

注意

这不会影响卷积操作的精度。然而,其他标志,如torch.backends.cudnn.allow_tf32,可以控制卷积操作的精度。

注意

此标志目前仅影响一种原生设备类型:CUDA。如果将其设置为“high”或“medium”,则在进行float32矩阵乘法计算时会使用TensorFloat32数据类型,相当于设置了torch.backends.cuda.matmul.allow_tf32 = True。若将其设置为“highest”(默认值),则内部计算将使用float32数据类型,相当于设置了torch.backends.cuda.matmul.allow_tf32 = False

参数

precision (str) – 可以设置为“highest”(默认值)、“high”或“medium”。(具体参见上述说明)

本页目录