torch.func.functionalize
- torch.func.functionalize(func, *, remove='mutations')
-
functionalize 是一种转换,可用于从函数中移除(中间)变异和别名,同时保持函数的语义不变。
functionalize(func)
返回一个与func
具有相同语义的新函数,但移除了所有中间的变量修改。每个在中间张量上执行的就地操作:intermediate.foo_()
被替换为其非就地操作等价形式:intermediate_updated = intermediate.foo()
。使用 functionalize 可以很方便地将 PyTorch 程序发送到那些无法轻松处理变异或别名操作的后端或编译器。
- 参数
-
-
func (Callable) – 一个可以接受一个或多个参数的 Python 函数。
-
remove (str) – 一个可选的字符串参数,其值可以是“mutations”或“mutations_and_views”。如果设置为“mutations”,则所有修改操作符将被替换为其非修改版本。如果设置为“mutations_and_views”,除了替换修改操作符外,还会替换所有的别名操作符。默认值:‘mutations’。
-
- 返回值
-
返回一个新的“功能化”函数。该函数接受与
func
相同的输入,并具有相同的行为,但会移除函数中对中间张量进行的所有修改和可选的别名操作。 - 返回类型
functionalize 也会移除函数输入上的变异操作(和视图)。但是为了保持语义,它会在转换完成后,检测哪些张量输入本应被修改,并在必要时将新的数据复制回去。
示例:
>>> import torch >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.func import functionalize >>> >>> # A function that uses mutations and views, but only on intermediate tensors. >>> def f(a): ... b = a + 1 ... c = b.view(-1) ... c.add_(1) ... return b ... >>> inpt = torch.randn(2) >>> >>> out1 = f(inpt) >>> out2 = functionalize(f)(inpt) >>> >>> # semantics are the same (outputs are equivalent) >>> print(torch.allclose(out1, out2)) True >>> >>> f_traced = make_fx(f)(inpt) >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> >>> print(f_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]) add_ = torch.ops.aten.add_(view, 1); view = None return add >>> print(f_no_mutations_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]); add = None add_1 = torch.ops.aten.add(view, 1); view = None view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None return view_1 >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view_copy = torch.ops.aten.view_copy(add, [-1]); add = None add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None return view_copy_1 >>> # A function that mutates its input tensor >>> def f(a): ... b = a.view(-1) ... b.add_(1) ... return a ... >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> # >>> # All mutations and views have been removed, >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input >>> # after the function has completed. >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): view_copy = torch.ops.aten.view_copy(a_1, [-1]) add = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None return view_copy_1
- 有几个“失败模式”值得关注,这些模式与将代码转换为函数式编程相关:
-
-
与其他 torch.func 转换一样,functionalize() 不适用于直接使用 .backward() 的函数。同样地,torch.autograd.grad 也不适用。如果你想使用 autograd,可以使用 functionalize(grad(f)) 直接计算梯度。
-
与其他 torch.func 转换一样,functionalize() 不支持全局状态。如果你在一个会访问非局部状态视图或修改的函数上调用 functionalize(f),功能化将直接失效,并将视图/修改调用传递给后端。一种解决方法是确保任何非局部状态创建都被包裹到一个更大的函数中,然后在这个大函数上进行 functionalize 调用。
-
resize_() 有一些限制:只有在被调整大小的张量不是视图的情况下,functionalize 才会对使用 resize_() 的程序生效。
-
as_strided() 有一些限制:对于生成具有重叠内存的张量的 as_strided() 调用,不能使用 functionalize。
-
最后,理解函数化的一个有用思维模型是:大多数用户编写的 PyTorch 程序都是使用公共的 torch API 编写。在执行时,torch 操作符通常会被分解为内部的 C++ “ATen” API。函数化的逻辑完全发生在 ATen 层级上。它知道如何将每个带有别名的操作符从 ATen 映射到其非别名等价操作符(例如
tensor.view({-1})
->at::view_copy(tensor, {-1})
),以及如何将每个修改状态的操作符从 ATen 映射到其非修改等价操作符(例如tensor.add_(1)
->at::add(tensor, -1)
)。同时,它会跟踪别名和修改状态的信息以确定何时进行修复。关于哪些 ATen 操作符是带有别名或修改状态的所有信息都来自https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml。