伪张量
代码: fake_tensor.py
动机
在进行Dynamo符号评估和编译器传递时,我们希望能够在不实际运行操作或破坏现有张量的情况下,了解张量运算的输出大小、数据类型和设备等信息。实际上执行操作会更慢(如果你进行了大量计算),并且需要占用大量的内存(在你编译程序的时候使用GPU内存是不好的)。假张量在各个方面都像一个真实的张量,除了它没有实际的数据。例如,在进行Dynamo跟踪时,我们需要通过用户的Tensor代码来回答关于中间结果的问题(比如用户在一个中间张量上执行了一个条件判断)。如果没有假张量,我们将无法准确地回答这些问题。
类似地,假设你想为一个张量存储元数据,比如在FX IR节点上的meta['val']。你可以直接在节点上存储一个假的张量,这样可以为你提供该张量所需的所有元数据,包括一些你可能没有考虑到的细节(例如别名关系)。
总体架构
所有假张量都与一个 FakeTensorMode 相关联。因为假张量主要用于分析真实张量,所以一般的工作流程是:你有一堆真实张量,分配一个 FakeTensorMode,然后使用 from_real_tensor 函数将这些真实张量转换为假张量,并对假张量进行操作。特别是,FakeTensorMode 持久地维护着一个映射表,该表将张量(和存储)映射到相同的存储上。如果你多次假化同一个张量,你会得到同一个假张量;如果你假化两个相互别名的张量,你将会得到两个别名为同一假存储的假张量。FakeTensors 是张量子类,因此在它们上面进行操作时会自动获得一个假张量,但在一般情况下,当你对假张量(例如,如果你正在运行 FX pass)进行操作时,希望在激活 FakeTensorMode 的情况下进行;张量操作将自动打开假张量模式并重新尝试。
伪张量表示为__torch_dispatch__ 张量的元张量子类。这意味着在内部,伪张量实际上是元设备张量;它们使用额外的可扩展钩子(特别是dispatch_device)来谎报张量的实际设备信息。这是早期伪张量中比较容易出错的部分之一:有时,伪张量太善于伪装成CPU或CUDA等设备了,并且最终会导致一个CPU内核被调用时试图通过数据指针访问伪张量的数据,这显然是不会工作的。如果你在伪张量代码中遇到段错误(segfault),首先应该检查的是C++ 调试回溯是在一个 CPU 内核(意外!)还是一个元内核(预期!)中?元内核就像真正的内核一样,但它唯一的作用是分配输出,不执行任何数据计算。
张量子类需要定义如何实现各种操作。这里是一些通用的假张量实现方法:
-
在输入的假张量上运行元内核,并将它们重新解释为元张量。这是通过一个神奇的上下文管理器 in_kernel_invocation_manager 实现的,它指示 PyTorch 将假张量视为其底层的元张量,而不是“解包”假张量为元张量(假张量就是元张量)。以这种方式表示假张量是为了避免需要同步两组元数据(即元张 tensor 的元数据和假张量的元数据);这种“是”的关系确保只有一个规范的元数据副本。
-
如果你是工厂函数,将会改用设备参数为“meta”来调用底层的工厂函数。
-
将生成的元张量转换为假张量,并确定其输出设备应该是哪个(这通常很简单,但在某些情况下会变得复杂,比如 CPU 标量提升或设备转换操作。)
API: 关键内容
非PT2用法(更多示例如下:test/test_fake_tensor.py):
# Create a fake mode from torch._subclasses.fake_tensor import FakeTensorMode fake_mode = FakeTensorMode() converter = fake_mode.fake_tensor_converter # Fakeify some real tensors fake_x = converter.from_real_tensor(fake_mode, x) with fake_mode: # Do some operations on the fake tensors fake_y = fake_x * 2 # Factory operations automatically get fakeified in the context manager fake_z = torch.empty(20)
问题:为什么你的输入是实数张量?
A: 在PT2上下文中,这是因为通常会进行即时编译。因此,在执行程序的同时编译图的每个输入时,就已经有了这些输入的实际值。
PT2 预AOT Autograd 的使用(这种情况比较特殊,通常不建议这样做):
# Fake mode is not enabled! from torch._guards import detect_fake_mode fake_mode = detect_fake_mode(args) # if fake_mode isn't None converter = fake_mode.fake_tensor_converter fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args] with fake_mode: ... do stuff with the fake args, if needed ...
detect_fake_mode将在多个位置搜索,以尝试找到与生命周期相关的“唯一”的假张量模式。通常它会从追踪上下文中获取。
使用 AOTAutograd 后的 PT2 用法:
# 仿冒模式已启用!example_inputs 通常是假数据 # TODO:我们需要考虑修改这一点 # 但仍需这样做以进入仿冒模式 fake_mode = detect_fake_mode(example_inputs) # 一般情况下,你无需手动开启它
其他有用的资料:
from torch._subclasses.fake_tensor import unset_fake_temporarily with unset_fake_temporarily(): # fake mode is disabled here, you can do real tensor compute
你可能会在什么情况下想要禁用假张量模式?通常不需要这么做。我们发现的一个特殊情况是,在假张量上实现常量传播时:在这种情况下,尽管处于假张量模式,但仍需执行实际的张量计算。
FakeTensorProp from torch.fx.passes.fake_tensor_prop gm: GraphModule real_inputs: List[Tensor] FakeTensorProp(gm).propagate(*real_inputs) # This will populate meta['val'] on all the FX nodes with a fake tensor # or if you have a preexisting fake mode, you should use it FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs) # There is also propagate_dont_convert_inputs if your inputs are already fake fake_inputs: List[FakeTensor] FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
详情
自动转换还是不转换?最初,FakeTensorMode 并不会在该模式区域内自动将真实张量转换为假张量。这样做的原因是避免出现以下错误操作:
with FakeTensorMode(): real_tensor.t_()
这段代码应该做什么?如果我们实际上修改了真实张量的元数据,这会令人惊讶。但同时,并没有明显的时机来创建一个FakeTensor。因此我们保守地决定让这段代码抛出一个错误:“在FakeTensorMode中使用非Fake Tensor输入调用操作符尚未支持,请先将所有Tensors转换为FakeTensors。”
实际上这个问题很烦人。比如,如果你有一个真实的nn.Module,并想用虚假的张量进行测试,你就需要把nn.Module也变得“虚拟”。这就催生了FakeCopyMode。
最终,我们选择了自动添加假数据生成功能。然而,在许多使用FakeTensorMode的场景中,默认情况下这项功能仍然未被启用。
假张量的元数据修改:如果你有一个假张量,并对其进行 t_() 操作,该假张量的元数据会发生变化。虽然这种情况表面上看起来合理,但有时你可能希望将假张量作为 FX 节点上的元数据进行存储。然而,这样做会使得旧的元数据失效,因此是不好的。
实际上,这里存在一个根本性的矛盾:假张量会维护关于张量极其准确的元数据,包括对象身份信息。如果 FX 图中的对象元数据随时间发生变化,则无法表示这种变化。大多数情况下,我们的严肃 FX 分析都是在功能化图上进行的,这些图不包含这种情况,但偶尔需要对非功能化图进行分析。也许将假张量放入 meta[‘val’] 中是一个错误。
关于张量子类
Fake 张量同时使用了子类和模式张量子类的模式。具体来说,FakeTensor.__torch_dispatch__ 启用了与 Fake 张量关联的 FakeTensorMode,并重新分派(依赖于 FakeTensorMode 来完成主要工作)。如果假张量操作接收到它不认识的子类参数,则会返回 NotImplemented,给其他子类一个先运行的机会(希望其能简化为普通的张量操作),然后再尝试一次。这可能会导致无限循环。
每个运算符是如何实现的?
不幸的是,任何给定的操作可能在多个复杂的地点实现。以下是一些重要的情况:
-
当张量的元素数量非常少时,张量子类可以进行有限的常量传播(这有助于处理我们立即对这些张量调用item()的情况)。
-
为了提高性能,我们为某些操作实现了一些快路径版本,这些版本完全使用假张量来完成。
-
如果你使用@custom_op生成自定义张量,它们将直接注册为伪张量的impl_abstract。
-
Fake tensor 本身为设备转换操作预设了一些特殊情形。
-
如果没有元实现或任何分解,我们将生成真实的零填充张量,并尝试直接运行操作符以查看其结果。然而,如果操作符尝试使用数据进行索引访问,则这可能会导致段错误。因此,默认情况下不会为自定义操作符启用此功能。
转换器的工作原理是什么?
由于假张量在对张量的具体属性非常敏感的场景下使用,因此它们会非常小心地进行转换,以保持叶子节点状态、requires_grad 属性、别名关系等一系列特性。大部分复杂的工作都是由 MetaConverter 完成的。
性能特点
你可能会认为假张量很快,因为它们不进行任何张量计算。但实际上,在小尺寸的张量下,我们完全受限于开销,并且假张量是在 Python 中实现的。对于单个张量操作,我们需要做大量的工作(因为这些操作是通过分解来实现的)。因此,实际上,特别是在涉及符号形状时,假张量运行得很慢。目前我们在假张量中有两个重要的快速路径,在实际应用中可以产生显著的效果:
-
点操作不经过 PrimTorch 分解,而是手动编写了它们的传播规则。
-
如果我们能够的话,我们就应该这么做。
假张量的副本?
有人希望将假张量作为用户输入传入PT2堆栈,这意味着我们需要能够创建一个表示假张量的假张量。虽然当前还不支持这一点,但实现起来可能不会太困难。
动态形状的交互
每个FakeTensorMode都包含一个ShapeEnv,用于跟踪所有的符号形状信息。它们的生命周期通常是一起开始和结束的。
由于FakeTensorMode具有ShapeEnv(而元实现没有),因此依赖于数据并需要分配未绑定的SymInt的元函数会存在于假张量中。假张量还负责缓存这些未绑定的SymInt,例如,如果你对同一个假张量调用两次nonzero()方法,你会得到相同的符号大小。