ExportDB
ExportDB 是一个集中化的数据集,包含了支持和不支持的导出案例。它面向希望了解哪些类型的代码被支持、导出的具体规则以及如何修改现有代码以使其与导出兼容的用户。请注意,这并不是 ExportDB 支持的所有内容的详尽列表,但它涵盖了用户最常遇到且最容易混淆的使用场景。
如果你认为某个特性需要更完善的导出支持,请在 pytorch/pytorch 仓库中创建一个 issue,并添加 module: export 标签。
标签
支持
assume_constant_result
原始源代码:
# mypy: allow-untyped-defs import torch import torch._dynamo as torchdynamo class AssumeConstantResult(torch.nn.Module): """ Applying `assume_constant_result` decorator to burn make non-tracable code as constant. """ @torchdynamo.assume_constant_result def get_item(self, y): return y.int().item() def forward(self, x, y): return x[: self.get_item(y)] example_args = (torch.randn(3, 2), torch.tensor(4)) tags = {"torch.escape-hatch"} model = AssumeConstantResult() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]", y: "i64[]"): slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4); x = None return (slice_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)]) Range constraints: {}
autograd_function
注意
标签:
支持级别:SUPPORTED
原始源代码:
# mypy: allow-untyped-defs import torch class MyAutogradFunction(torch.autograd.Function): @staticmethod def forward(ctx, x): return x.clone() @staticmethod def backward(ctx, grad_output): return grad_output + 1 class AutogradFunction(torch.nn.Module): """ TorchDynamo does not keep track of backward() on autograd functions. We recommend to use `allow_in_graph` to mitigate this problem. """ def forward(self, x): return MyAutogradFunction.apply(x) example_args = (torch.randn(3, 2),) model = AutogradFunction() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): clone: "f32[3, 2]" = torch.ops.aten.clone.default(x); x = None return (clone,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='clone'), target=None)]) Range constraints: {}
class_method
注意
标签:
支持级别:SUPPORTED
原始源代码:
# mypy: allow-untyped-defs import torch class ClassMethod(torch.nn.Module): """ Class methods are inlined during tracing. """ @classmethod def method(cls, x): return x + 1 def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 2) def forward(self, x): x = self.linear(x) return self.method(x) * self.__class__.method(x) * type(self).method(x) example_args = (torch.randn(3, 4),) model = ClassMethod() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_linear_weight: "f32[2, 4]", p_linear_bias: "f32[2]", x: "f32[3, 4]"): linear: "f32[3, 2]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None add: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1) add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1) mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1); add = add_1 = None add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1); linear = None mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2); mul = add_2 = None return (mul_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)]) Range constraints: {}
cond_branch_class_method
原始源代码:
# mypy: allow-untyped-defs import torch from functorch.experimental.control_flow import cond class MySubModule(torch.nn.Module): def foo(self, x): return x.cos() def forward(self, x): return self.foo(x) class CondBranchClassMethod(torch.nn.Module): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using class method in cond(). NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def __init__(self) -> None: super().__init__() self.subm = MySubModule() def bar(self, x): return x.sin() def forward(self, x): return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) example_args = (torch.randn(3),) tags = { "torch.cond", "torch.dynamic-shape", } model = CondBranchClassMethod() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3]"): sin: "f32[3]" = torch.ops.aten.sin.default(x); x = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)]) Range constraints: {}
cond_branch_nested_function
原始源代码:
# mypy: allow-untyped-defs import torch from functorch.experimental.control_flow import cond class CondBranchNestedFunction(torch.nn.Module): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using nested function in cond(). NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def forward(self, x): def true_fn(x): def inner_true_fn(y): return x + y return inner_true_fn(x) def false_fn(x): def inner_false_fn(y): return x - y return inner_false_fn(x) return cond(x.shape[0] < 10, true_fn, false_fn, [x]) example_args = (torch.randn(3),) tags = { "torch.cond", "torch.dynamic-shape", } model = CondBranchNestedFunction() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3]"): add: "f32[3]" = torch.ops.aten.add.Tensor(x, x); x = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
cond_branch_nonlocal_variables
原始源代码:
# mypy: allow-untyped-defs import torch from functorch.experimental.control_flow import cond class CondBranchNonlocalVariables(torch.nn.Module): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. The code below will not work because capturing closure variables is not supported. ``` my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(y): nonlocal my_tensor_var, my_primitive_var return y + my_tensor_var + my_primitive_var def false_fn(y): nonlocal my_tensor_var, my_primitive_var return y - my_tensor_var - my_primitive_var return cond(x.shape[0] > 5, true_fn, false_fn, [x]) ``` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def forward(self, x): my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(x, y, z): return x + y + z def false_fn(x, y, z): return x - y - z return cond( x.shape[0] > 5, true_fn, false_fn, [x, my_tensor_var, torch.tensor(my_primitive_var)], ) example_args = (torch.randn(6),) tags = { "torch.cond", "torch.dynamic-shape", } model = CondBranchNonlocalVariables() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"): add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100) lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None detach: "f32[]" = torch.ops.aten.detach.default(lift_fresh_copy); lift_fresh_copy = None add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add); x = add = None add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach); add_1 = detach = None return (add_2,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)]) Range constraints: {}
cond_closed_over_variable
原始源代码:
# mypy: allow-untyped-defs import torch from functorch.experimental.control_flow import cond class CondClosedOverVariable(torch.nn.Module): """ torch.cond() supports branches closed over arbitrary variables. """ def forward(self, pred, x): def true_fn(val): return x * 2 def false_fn(val): return x - 2 return cond(pred, true_fn, false_fn, [x + 1]) example_args = (torch.tensor(True), torch.randn(3, 2)) tags = {"torch.cond", "python.closure"} model = CondClosedOverVariable() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, pred: "b8[]", x: "f32[3, 2]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, [x]); pred = true_graph_0 = false_graph_0 = x = None getitem: "f32[3, 2]" = cond[0]; cond = None return (getitem,) class true_graph_0(torch.nn.Module): def forward(self, x: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2); x = None return (mul,) class false_graph_0(torch.nn.Module): def forward(self, x: "f32[3, 2]"): sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2); x = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='pred'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {}
cond_operands
原始源代码:
# mypy: allow-untyped-defs import torch from torch.export import Dim from functorch.experimental.control_flow import cond x = torch.randn(3, 2) y = torch.randn(2) dim0_x = Dim("dim0_x") class CondOperands(torch.nn.Module): """ The operands passed to cond() must be: - a list of tensors - match arguments of `true_fn` and `false_fn` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def forward(self, x, y): def true_fn(x, y): return x + y def false_fn(x, y): return x - y return cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) example_args = (x, y) tags = { "torch.cond", "torch.dynamic-shape", } extra_inputs = (torch.randn(2, 2), torch.randn(2)) dynamic_shapes = {"x": {0: dim0_x}, "y": None} model = CondOperands() torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[s0, 2]", y: "f32[2]"): # sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0) gt: "Sym(s0 > 2)" = sym_size_int_1 > 2; sym_size_int_1 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, y]); gt = true_graph_0 = false_graph_0 = x = y = None getitem: "f32[s0, 2]" = cond[0]; cond = None return (getitem,) class true_graph_0(torch.nn.Module): def forward(self, x: "f32[s0, 2]", y: "f32[2]"): add_3: "f32[s0, 2]" = torch.ops.aten.add.Tensor(x, y); x = y = None return (add_3,) class false_graph_0(torch.nn.Module): def forward(self, x: "f32[s0, 2]", y: "f32[2]"): sub_1: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(x, y); x = y = None return (sub_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {s0: VR[0, int_oo]}
cond_predicate
原始源代码:
# mypy: allow-untyped-defs import torch from functorch.experimental.control_flow import cond class CondPredicate(torch.nn.Module): """ The conditional statement (aka predicate) passed to cond() must be one of the following: - torch.Tensor with a single element - boolean expression NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def forward(self, x): pred = x.dim() > 2 and x.shape[2] > 10 return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) example_args = (torch.randn(6, 4, 3),) tags = { "torch.cond", "torch.dynamic-shape", } model = CondPredicate() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[6, 4, 3]"): sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x); x = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)]) Range constraints: {}
constrain_as_size_example
原始源代码:
# mypy: allow-untyped-defs import torch class ConstrainAsSizeExample(torch.nn.Module): """ If the value is not known at tracing time, you can provide hint so that we can trace further. Please look at torch._check and torch._check_is_size APIs. torch._check_is_size is used for values that NEED to be used for constructing tensor. """ def forward(self, x): a = x.item() torch._check_is_size(a) torch._check(a <= 5) return torch.zeros((a, 5)) example_args = (torch.tensor(4),) tags = { "torch.dynamic-value", "torch.escape-hatch", } model = ConstrainAsSizeExample() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "i64[]"): item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None # sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None ge_3: "Sym(u0 >= 0)" = item >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_default = None le_1: "Sym(u0 <= 5)" = item <= 5 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False); item = None return (zeros,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)]) Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}
constrain_as_value_example
原始源代码:
# mypy: allow-untyped-defs import torch class ConstrainAsValueExample(torch.nn.Module): """ If the value is not known at tracing time, you can provide hint so that we can trace further. Please look at torch._check and torch._check_is_size APIs. torch._check is used for values that don't need to be used for constructing tensor. """ def forward(self, x, y): a = x.item() torch._check(a >= 0) torch._check(a <= 5) if a < 6: return y.sin() return y.cos() example_args = (torch.tensor(4), torch.randn(5, 5)) tags = { "torch.dynamic-value", "torch.escape-hatch", } model = ConstrainAsValueExample() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "i64[]", y: "f32[5, 5]"): item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None ge_1: "Sym(u0 >= 0)" = item >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None le_1: "Sym(u0 <= 5)" = item <= 5; item = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None sin: "f32[5, 5]" = torch.ops.aten.sin.default(y); y = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)]) Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}
decorator
注意
标签:
支持级别:SUPPORTED
原始源代码:
# mypy: allow-untyped-defs import functools import torch def test_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) + 1 return wrapper class Decorator(torch.nn.Module): """ Decorators calls are inlined into the exported function during tracing. """ @test_decorator def forward(self, x, y): return x + y example_args = (torch.randn(3, 2), torch.randn(3, 2)) model = Decorator() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]", y: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, y); x = y = None add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None return (add_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)]) Range constraints: {}
dictionary
原始源代码:
# mypy: allow-untyped-defs import torch class Dictionary(torch.nn.Module): """ Dictionary structures are inlined and flattened along tracing. """ def forward(self, x, y): elements = {} elements["x2"] = x * x y = y * elements["x2"] return {"y": y} example_args = (torch.randn(3, 2), torch.tensor(4)) tags = {"python.data-structure"} model = Dictionary() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]", y: "i64[]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x); x = None mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(y, mul); y = mul = None return (mul_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)]) Range constraints: {}
dynamic_shape_assert
原始源代码:
# mypy: allow-untyped-defs import torch class DynamicShapeAssert(torch.nn.Module): """ A basic usage of python assertion. """ def forward(self, x): # assertion with error message assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" # assertion without error message assert x.shape[0] > 1 return x example_args = (torch.randn(3, 2),) tags = {"python.assert"} model = DynamicShapeAssert() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): return (x,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='x'), target=None)]) Range constraints: {}
dynamic_shape_constructor
原始源代码:
# mypy: allow-untyped-defs import torch class DynamicShapeConstructor(torch.nn.Module): """ Tensor constructors should be captured with dynamic shape inputs rather than being baked in with static shape. """ def forward(self, x): return torch.zeros(x.shape[0] * 2) example_args = (torch.randn(3, 2),) tags = {"torch.dynamic-shape"} model = DynamicShapeConstructor() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device = device(type='cpu'), pin_memory = False) return (zeros,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)]) Range constraints: {}
dynamic_shape_if_guard
原始源代码:
# mypy: allow-untyped-defs import torch class DynamicShapeIfGuard(torch.nn.Module): """ `if` statement with backed dynamic shape predicate will be specialized into one particular branch and generate a guard. However, export will fail if the the dimension is marked as dynamic shape from higher level API. """ def forward(self, x): if x.shape[0] == 3: return x.cos() return x.sin() example_args = (torch.randn(3, 2, 2),) tags = {"torch.dynamic-shape", "python.control-flow"} model = DynamicShapeIfGuard() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2, 2]"): cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x); x = None return (cos,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)]) Range constraints: {}
dynamic_shape_map
原始源代码:
# mypy: allow-untyped-defs import torch from functorch.experimental.control_flow import map class DynamicShapeMap(torch.nn.Module): """ functorch map() maps a function over the first tensor dimension. """ def forward(self, xs, y): def body(x, y): return x + y return map(body, xs, y) example_args = (torch.randn(3, 2), torch.randn(2)) tags = {"torch.dynamic-shape", "torch.map"} model = DynamicShapeMap() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, xs: "f32[3, 2]", y: "f32[2]"): body_graph_0 = self.body_graph_0 map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]); body_graph_0 = xs = y = None getitem: "f32[3, 2]" = map_impl[0]; map_impl = None return (getitem,) class body_graph_0(torch.nn.Module): def forward(self, xs: "f32[2]", y: "f32[2]"): add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y); xs = y = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {}
dynamic_shape_slicing
原始源代码:
# mypy: allow-untyped-defs import torch class DynamicShapeSlicing(torch.nn.Module): """ Slices with dynamic shape arguments should be captured into the graph rather than being baked in. """ def forward(self, x): return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] example_args = (torch.randn(3, 2),) tags = {"torch.dynamic-shape"} model = DynamicShapeSlicing() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1); x = None slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2); slice_1 = None return (slice_2,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)]) Range constraints: {}
dynamic_shape_view
原始源代码:
# mypy: allow-untyped-defs import torch class DynamicShapeView(torch.nn.Module): """ Dynamic shapes should be propagated to view arguments instead of being baked into the exported graph. """ def forward(self, x): new_x_shape = x.size()[:-1] + (2, 5) x = x.view(*new_x_shape) return x.permute(0, 2, 1) example_args = (torch.randn(10, 10),) tags = {"torch.dynamic-shape"} model = DynamicShapeView() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[10, 10]"): view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]); x = None permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]); view = None return (permute,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)]) Range constraints: {}
fn_with_kwargs
原始源代码:
# mypy: allow-untyped-defs import torch class FnWithKwargs(torch.nn.Module): """ Keyword arguments are not supported at the moment. """ def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs): out = pos0 for arg in tuple0: out = out * arg for arg in myargs: out = out * arg out = out * mykw0 out = out * mykwargs["input0"] * mykwargs["input1"] return out example_args = ( torch.randn(4), (torch.randn(4), torch.randn(4)), *[torch.randn(4), torch.randn(4)] ) example_kwargs = { "mykw0": torch.randn(4), "input0": torch.randn(4), "input1": torch.randn(4), } tags = {"python.data-structure"} model = FnWithKwargs() torch.export.export(model, example_args, example_kwargs)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, pos0: "f32[4]", tuple0_0: "f32[4]", tuple0_1: "f32[4]", myargs_0: "f32[4]", myargs_1: "f32[4]", mykw0: "f32[4]", input0: "f32[4]", input1: "f32[4]"): mul: "f32[4]" = torch.ops.aten.mul.Tensor(pos0, tuple0_0); pos0 = tuple0_0 = None mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, tuple0_1); mul = tuple0_1 = None mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, myargs_0); mul_1 = myargs_0 = None mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, myargs_1); mul_2 = myargs_1 = None mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, mykw0); mul_3 = mykw0 = None mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, input0); mul_4 = input0 = None mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1); mul_5 = input1 = None return (mul_6,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='pos0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tuple0_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tuple0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='myargs_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='myargs_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='mykw0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_6'), target=None)]) Range constraints: {}
list_contains
原始源代码:
# mypy: allow-untyped-defs import torch class ListContains(torch.nn.Module): """ List containment relation can be checked on a dynamic shape or constants. """ def forward(self, x): assert x.size(-1) in [6, 2] assert x.size(0) not in [4, 5, 6] assert "monkey" not in ["cow", "pig"] return x + x example_args = (torch.randn(3, 2),) tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"} model = ListContains() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x); x = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
list_unpack
原始源代码:
# mypy: allow-untyped-defs from typing import List import torch class ListUnpack(torch.nn.Module): """ Lists are treated as static construct, therefore unpacking should be erased after tracing. """ def forward(self, args: List[torch.Tensor]): """ Lists are treated as static construct, therefore unpacking should be erased after tracing. """ x, *y = args return x + y[0] example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],) tags = {"python.control-flow", "python.data-structure"} model = ListUnpack() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1); args_0 = args_1 = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
nested_function
原始源代码:
# mypy: allow-untyped-defs import torch class NestedFunction(torch.nn.Module): """ Nested functions are traced through. Side effects on global captures are not supported though. """ def forward(self, a, b): x = a + b z = a - b def closure(y): nonlocal x x += 1 return x * y + z return closure(x) example_args = (torch.randn(3, 2), torch.randn(2)) tags = {"python.closure"} model = NestedFunction() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, a: "f32[3, 2]", b: "f32[2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b) sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b); a = b = None add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1); add_1 = None add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub); mul = sub = None return (add_2,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)]) Range constraints: {}
null_context_manager
原始源代码:
# mypy: allow-untyped-defs import contextlib import torch class NullContextManager(torch.nn.Module): """ Null context manager in Python will be traced out. """ def forward(self, x): """ Null context manager in Python will be traced out. """ ctx = contextlib.nullcontext() with ctx: return x.sin() + x.cos() example_args = (torch.randn(3, 2),) tags = {"python.context-manager"} model = NullContextManager() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): sin: "f32[3, 2]" = torch.ops.aten.sin.default(x) cos: "f32[3, 2]" = torch.ops.aten.cos.default(x); x = None add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
pytree_flatten
注意
标签:
支持级别:SUPPORTED
原始源代码:
# mypy: allow-untyped-defs import torch from torch.utils import _pytree as pytree class PytreeFlatten(torch.nn.Module): """ Pytree from PyTorch can be captured by TorchDynamo. """ def forward(self, x): y, spec = pytree.tree_flatten(x) return y[0] + 1 example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},), model = PytreeFlatten() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x_0_1: "f32[3, 2]", x_0_2: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x_0_1, 1); x_0_1 = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x_0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x_0_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
scalar_output
原始源代码:
# mypy: allow-untyped-defs import torch from torch.export import Dim x = torch.randn(3, 2) dim1_x = Dim("dim1_x") class ScalarOutput(torch.nn.Module): """ Returning scalar values from the graph is supported, in addition to Tensor outputs. Symbolic shapes are captured and rank is specialized. """ def __init__(self) -> None: super().__init__() def forward(self, x): return x.shape[1] + 1 example_args = (x,) tags = {"torch.dynamic-shape"} dynamic_shapes = {"x": {1: dim1_x}} model = ScalarOutput() torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, s0]"): sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 1); x = None add: "Sym(s0 + 1)" = sym_size_int_1 + 1; sym_size_int_1 = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)]) Range constraints: {s0: VR[0, int_oo]}
specialized_attribute
注意
标签:
支持级别:SUPPORTED
原始源代码:
# mypy: allow-untyped-defs from enum import Enum import torch class Animal(Enum): COW = "moo" class SpecializedAttribute(torch.nn.Module): """ Model attributes are specialized. """ def __init__(self) -> None: super().__init__() self.a = "moo" self.b = 4 def forward(self, x): if self.a == Animal.COW.value: return x * x + self.b else: raise ValueError("bad") example_args = (torch.randn(3, 2),) model = SpecializedAttribute() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x); x = None add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4); mul = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
static_for_loop
原始源代码:
# mypy: allow-untyped-defs import torch class StaticForLoop(torch.nn.Module): """ A for loop with constant number of iterations should be unrolled in the exported graph. """ def forward(self, x): ret = [] for i in range(10): # constant ret.append(i + x) return ret example_args = (torch.randn(3, 2),) tags = {"python.control-flow"} model = StaticForLoop() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0) add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1) add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2) add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3) add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4) add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5) add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6) add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7) add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8) add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9); x = None return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)]) Range constraints: {}
static_if
原始源代码:
# mypy: allow-untyped-defs import torch class StaticIf(torch.nn.Module): """ `if` statement with static predicate value should be traced through with the taken branch. """ def forward(self, x): if len(x.shape) == 3: return x + torch.ones(1, 1, 1) return x example_args = (torch.randn(3, 2, 2),) tags = {"python.control-flow"} model = StaticIf() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2, 2]"): ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False) add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones); x = ones = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
tensor_setattr
原始源代码:
# mypy: allow-untyped-defs import torch class TensorSetattr(torch.nn.Module): """ setattr() call onto tensors is not supported. """ def forward(self, x, attr): setattr(x, attr, torch.randn(3, 2)) return x + 4 example_args = (torch.randn(3, 2), "attr") tags = {"python.builtin"} model = TensorSetattr() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]", attr): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4); x = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=ConstantArgument(name='attr', value='attr'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
type_reflection_method
原始源代码:
# mypy: allow-untyped-defs import torch class A: @classmethod def func(cls, x): return 1 + x class TypeReflectionMethod(torch.nn.Module): """ type() calls on custom objects followed by attribute accesses are not allowed due to its overly dynamic nature. """ def forward(self, x): a = A() return type(a).func(x) example_args = (torch.randn(3, 4),) tags = {"python.builtin"} model = TypeReflectionMethod() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 4]"): add: "f32[3, 4]" = torch.ops.aten.add.Tensor(x, 1); x = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]) Range constraints: {}
user_input_mutation
原始源代码:
# mypy: allow-untyped-defs import torch class UserInputMutation(torch.nn.Module): """ Directly mutate user input in forward """ def forward(self, x): x.mul_(2) return x.cos() example_args = (torch.randn(3, 2),) tags = {"torch.mutation"} model = UserInputMutation() torch.export.export(model, example_args)
结果:
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2); x = None cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul) return (mul, cos) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_INPUT_MUTATION: 6>, arg=TensorArgument(name='mul'), target='x'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)]) Range constraints: {}
尚未支持
dynamic_shape_round
原始源代码:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
from torch.export import Dim


class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def forward(self, x):
        """Return the leading ``round(x.shape[0] / 2)`` rows of ``x``."""
        # Dimension 0 is declared dynamic via `dynamic_shapes` below, so the
        # value passed to round() is derived from a dynamic size.
        return x[: round(x.shape[0] / 2)]


x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_args = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()

# This call is the documented failure for this case.
torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)
结果:
Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".
model_attr_mutation
原始源代码:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class ModelAttrMutation(torch.nn.Module):
    """
    Attribute mutation is not supported.
    """

    def __init__(self) -> None:
        super().__init__()
        # A plain Python list of tensors stored on the module.
        self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

    def recreate_list(self):
        """Return a new two-element list of zero tensors of shape (3, 2)."""
        return [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        """
        Replace ``self.attr_list`` with a fresh list and return
        ``x.sum() + self.attr_list[0].sum()``.
        """
        # Re-assigning a module attribute inside forward is the mutation this
        # case demonstrates.
        self.attr_list = self.recreate_list()
        return x.sum() + self.attr_list[0].sum()


example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()

# This call is the documented failure for this case.
torch.export.export(model, example_args)
结果:
AssertionError: Mutating module attribute attr_list during export.
optional_input
原始源代码:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class OptionalInput(torch.nn.Module):
    """
    Tracing through optional input is not supported yet
    """

    # The default for `y` is a tensor created once, when the method is defined;
    # it is kept as-is because the optional argument is what this case exercises.
    def forward(self, x, y=torch.randn(2, 3)):
        """Return ``x + y`` when ``y`` is not None, otherwise ``x``."""
        if y is not None:
            return x + y
        return x


# Only `x` is supplied, so `y` takes its default value.
example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()

# This call is the documented failure for this case.
torch.export.export(model, example_args)
结果:
Unsupported: Tracing through optional input is not supported yet
unsupported_operator
原始源代码:
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class TorchSymMin(torch.nn.Module):
    """
    torch.sym_min operator is not supported in export.
    """

    def forward(self, x):
        """Return ``x.sum()`` plus the smaller of ``x.size(0)`` and 100."""
        return x.sum() + torch.sym_min(x.size(0), 100)


example_args = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()

# This call is the documented failure for this case.
torch.export.export(model, example_args)
结果:
Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f98c3497040>