动态形状

代码: symbolic_shapes.py

参见:《动态形状手册》

动机

深度学习编译器通常只支持静态形状,也就是说,生成的编译程序仅适用于特定的输入形状配置,并且如果输入形状发生变化,则需要重新编译。这种假设对于大多数常用的深度学习模型来说非常有效,但在某些情况下是不够的:

  • 某些维度(如批量大小或序列长度)可能会发生变化。例如,执行自适应批处理的推理服务会根据在批处理窗口内接收到的请求数量,以不同的批量大小来执行推理请求。我们还可能希望仅将可变长度的序列填充到当前批次中的最大序列长度,而这个长度可能会从一批次变化到另一批次。

  • 某些模型的输出形状会根据输入数据的变化而变化,也就是说,它们的输出和中间结果大小可能取决于实际输入的数据,并且这些数据在不同运行中可能会有所不同。例如,检测模型可能会先生成一个可变数量的潜在边界框,然后再运行一个更昂贵的图像识别模型来确定目标是否在一个边界框内。边界框的数量是依赖于数据的。

  • 当处理稀疏表示(如稀疏张量、不规则张量和图神经网络)时,会出现一个特别重要的数据依赖形状的情况。在这种情况下,要处理的数据量取决于问题的稀疏结构,并且这种结构通常会以数据依赖的方式变化。

在支持动态形状时,我们选择不支持动态秩程序(即输入张量维度发生变化的程序),因为这种情况在实际深度学习程序中很少见,并且可以避免对形状符号列表进行归纳推理的需求。

简化后的公共API

PyTorch 2.1 的默认行为是:

  • PT2 默认认为一切都是静态的

  • 如果我们因为大小变化而重新编译,我们将尝试将该大小视为动态大小进行处理(已经发生变化的大小在未来可能会继续变化)。然而,这种泛化可能会失败(例如,由于用户代码在相关大小上进行了条件分支或 PT2 中缺少对动态形状的支持)。如果你想知道为什么 PT2 对某些代码过度专业化,请使用 TORCH_LOGS=dynamic 运行,并查找“eval”条目中显示何时添加了保护以及原因的日志。

  • 如果你提前知道某些内容将是动态的,可以使用torch._dynamo.mark_dynamic(tensor, dim)跳过第一次重新编译。如果提前知道了这个维度可能取到的最小值min和最大值max,你可以指定torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)

  • 如果你使用 torch.compile(dynamic=False),我们将关闭在重新编译时的自动动态形状,并为每个不同的大小进行单独编译。相反,如果你使用 torch.compile(dynamic=True),我们将尽量使一切尽可能地动态化。这主要适用于小型操作符;如果在一个大型模型上尝试这样做,它会(1)很可能导致PT2崩溃,以及(2)毫无理由地运行得非常慢。

守卫模型

在为TorchDynamo和TorchInductor添加对动态形状支持的设计过程中,我们做出了一项重要决策:为了重用针对PyTorch API编写的Python/C++代码(包括分解等现有代码),我们需要能够通过动态形状进行追踪。与完全符号化的系统不同,后者会捕获条件语句的两个分支,我们只选择其中一个分支,并假设将来在相同条件下仍会选择该分支来进行专门化追踪。为此,我们为每个符号大小维护一个“提示”,说明其在编译时的具体值(由于TorchDynamo是一个即时编译器,它始终知道实际输入尺寸)。当我们在张量上执行条件判断时,我们会查看这个提示来决定选择哪个分支。

这大大简化了我们生成的符号形状公式,但需要一个更复杂的系统来管理守卫。例如,考虑以下程序:

def f(x, y):
    z = torch.cat([x, y])
    if z.size(0) > 2:
        return z.mul(2)
    else:
        return z.add(2)

我们将使用 TorchInductor 编译的最终 IR 将是 torch.cat([x, y]).add(2) 或者 torch.cat([x, y]).mul(2)(条件被扁平化)。为了确定我们处于哪个分支,我们需要知道中间变量 z 的大小。由于 TorchDynamo 必须提前判断编译的跟踪是否有效(我们不支持像某些 JIT 编译器那样的退出机制),我们必须能够将 z.size(0) 表达为输入张量大小之和的形式,即 x.size(0) + y.size(0)。这是通过为 PyTorch 中的所有操作符编写元函数来实现的,这些元函数可以在不实际执行计算的情况下将大小信息传播到张量的输出。

总体架构

符号形状_workflow:

  1. 当我们在Dynamo中开始编译一个帧时,会分配一个与FakeTensorMode关联的ShapeEnv来跟踪符号形状的状态。

  2. 我们在入口为张量分配符号大小(静态还是动态由策略决定,可以进行调整)。

  3. 我们将符号大小通过操作符进行传播,并同时维护两部分内容:(1) FX IR,以确保能够忠实导出符号计算;(2) 表示大小变量的 Sympy 表达式,以便我们可以对它们进行推理。

  4. 当我们在 Dynamo 跟踪或 Inductor 优化中基于符号大小进行条件判断时,会根据条件添加相应的保护措施。这些保护措施既可以通过 Python 引入,也可以通过 C++ 引入。

  5. 这些守卫可以进一步简化符号变量。例如,如果你断言s0 == 4,我们现在可以将所有出现的s0 替换为 4

  6. 完成追踪和优化后,我们将所有守卫与编译代码一同安装。只有在所有守卫均为真的情况下,编译代码才可被重复使用。

重要文件:

  • C++ SymInt API:c10/core/SymInt.hSymFloat.hSymBool.h

  • Python SymInt API: torch/__init__.py(请查找 SymInt/SymFloat/SymBool

  • C++ 实现细节: c10/core/SymNodeImpl.h, torch/csrc/utils/python_symnode.h, torch/csrc/jit/python/init.cpp

  • Python基础设施: torch/fx/experimental/symbolic_shapes.py

  • 其他重要的文件有:torch/_subclasses/fake_tensor.pytorch/_meta_registrations.py,以及 decomps 和 PrimTorch 参考。

简化版内部API

理解Python的类层级结构:

  • SymInt/SymFloat/SymBool:这些是用户可见的类,用于模拟对应的 int、float 和 bool 类型。当你将两个 SymInt 相加时,系统会返回一个新的 SymInt,它符号化地记录了整数加法已经发生。

  • SymNode:这是内部结构(可通过例如 symint.node 访问),用于保存实际的符号跟踪信息。SymNode 是类型擦除的,这使得表示混合类型的操作更加方便。需要注意的是,从技术上讲,你不必从 SymInt 调用 Python 的 SymNode;例如,XLA 的 C++ SymNodeImpl 将会替代 SymNode。

  • ShapeEnv:编译时的上下文状态,用于跟踪迄今为止累积的所有自由符号和守卫。每个SymNode都会记录它的ShapeEnv(反之则不成立;只有当SymNode参与到某个守卫中时才会被使用)。

C++也十分相似:

  • c10::SymInt/SymFloat/SymBool:用户可见的类,用于模拟int/float/bool。

  • c10::SymNode/SymNodeImpl:类似于 SymNode

  • C++ 中没有 ShapeEnv;为了便于调试,整个符号推理机制都是用 Python 实现的。

当你编写可以使用 make_fx 追踪的代码时,它必须能够处理 SymInt、SymFloat 和 SymBool。你可以参考动态形状手册 以获取更多关于如何实现这一点的指导。

DimDynamic 策略

符号推理:

  • 值范围

  • Sympy 使用说明

  • 约束条件

  • DimDynamic/Constraint

未支持的 SymInts

为了解析控制流,我们检查符号整数的实际值(即提示),以确定要采取哪个分支。然而,在某些情况下,我们可能没有实际值:当大小变量从依赖于数据的操作中出现时,如.nonzero().item(),就会产生所谓的未被支持的符号整数。在这些符号整数上执行控制流是非法的,因此我们必须在这类操作上进行图中断。

简单地实现这一点会过于严格:如果你尝试使用未被支持的符号整数进行任何操作,大多数 PyTorch 程序会立即失败。以下是一些最重要的改进,以使其真正可行:

  • 在创建张量时,PyTorch 会预先计算许多关于该张量的信息;例如,如果你使用 empty_strided 来创建一个张量,系统会急切地对步长进行排序,并判断张量是否是非重叠且密集的。这种操作会产生大量的保护措施。然而,更常见的情况是直接使用高级 API 如 empty 创建张量,这可以保证生成非重叠且密集的张量。我们修改了 PyTorch 以避免不必要的重新计算这些属性。

  • 即使需要进行复杂的计算,有时某个属性实际上从不会被查询到。通过将这些预计算的属性改为惰性计算,我们可以在真正需要时才对其进行处理,从而避免不必要的资源消耗和潜在的问题。

  • 整数张量中的数据通常是不确定是否为非负值的。但是,我们提供了一个 API constrain_range,允许用户指定某个大小在一个已知范围内的上限和下限。

在未来版本的PT2(高于PT2.1)中,我们将扩展推理系统,根据使用情况推断未支持的符号整数是大小相关的。例如,如果你将.item()调用的结果传递给像torch.empty这样的工厂函数,我们将自动推断该结果是一个尺寸(否则将会失败)。此假设将在运行时进行验证,并在不满足条件时引发错误。

本页目录