控制流程 - 条件
torch.cond 是一个结构化的控制流操作符,可以用来指定类似于 if-else 的控制流。从逻辑上讲,它相当于以下实现方式。
def cond( pred: Union[bool, torch.Tensor], true_fn: Callable, false_fn: Callable, operands: Tuple[torch.Tensor] ): if pred: return true_fn(*operands) else: return false_fn(*operands)
其独特之处在于能够表达数据依赖的控制流:它转化为一个条件操作符(torch.ops.higher_order.cond),保留了谓词、真函数和假函数。这使得在编写和部署模型时,可以根据输入或张量运算中间输出的值或形状灵活地改变模型架构。
警告
torch.cond 是 PyTorch 中的一个原型功能,目前对输入和输出类型的支持有限,并且不支持训练。请期待未来版本的 PyTorch 中更稳定的实现。更多关于特性分类的信息,请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
例子
以下是一个使用 cond 按输入形状进行分支的示例:
import torch def true_fn(x: torch.Tensor): return x.cos() + x.sin() def false_fn(x: torch.Tensor): return x.sin() class DynamicShapeCondPredicate(torch.nn.Module): """ A basic usage of cond based on dynamic shape predicate. """ def __init__(self): super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,)) dyn_shape_mod = DynamicShapeCondPredicate()
我们可以立即运行模型,并预期结果会根据输入形状的不同而变化。
inp = torch.randn(3) inp2 = torch.randn(5) assert torch.equal(dyn_shape_mod(inp), false_fn(inp)) assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
我们可以将模型导出来进行进一步的转化和部署:
inp = torch.randn(4, 3) dim_batch = torch.export.Dim("batch", min=2) ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}}) print(ep)
这为我们提供了如下所示的导出程序:
class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[s0, 3]): sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0) gt: Sym(s0 > 4) = sym_size > 4; sym_size = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None return (conditional,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: f32[s0, 3]): cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1) sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None return add class <lambda>(torch.nn.Module): def forward(self, arg0_1: f32[s0, 3]): sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return sin
请注意,torch.cond 被转换为 torch.ops.higher_order.cond。它的谓词变为基于输入形状的符号表达式,而分支函数则成为顶级图模块中的两个子图属性。
这里有一个例子展示了如何表达基于数据的控制流程。
class DataDependentCondPredicate(torch.nn.Module): """ A basic usage of cond based on data dependent predicate. """ def __init__(self): super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
导出的程序如下:
class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[s0, 3]): sum_1: f32[] = torch.ops.aten.sum.default(arg0_1) gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None return (conditional,) class <lambda>(torch.nn.Module): def forward(self, arg0_1: f32[s0, 3]): cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1) sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None return add class <lambda>(torch.nn.Module): def forward(self, arg0_1: f32[s0, 3]): sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return sin
torch.ops.higher_order.cond 的特性
torch.ops.higher_order.cond 有一些有用的性质:
-
- 对于谓词:
-
-
谓词的动态特性得以保留(例如,上述示例中的gt)
-
如果用户程序中的谓词是常量(例如 Python 中的布尔常量),操作符的 pred 将是一个常量。
-
-
- 分支:
-
-
输入和输出的签名将以展平的元组形式呈现。
-
它们是torch.fx.GraphModule。
-
原函数中的闭包变成了显式输入,不再有闭包。
-
不得对输入或全局变量进行任何修改。
-
-
- 对操作数:
-
-
它也将是一个扁平化的元组。
-
-
在用户程序中,torch.cond 的嵌套变为嵌套图模块。
API参考
- torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands)
-
根据条件选择应用true_fn或false_fn。
警告
torch.cond 是 PyTorch 中的一个原型功能,目前对输入和输出类型的支持有限,并且不支持训练。请期待未来版本的 PyTorch 中更稳定的实现。更多关于特性分类的信息,请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
cond 是一个结构化的控制流操作符。它类似于 Python 的 if 语句,但对 true_fn、false_fn 和 operands 有一些限制,这些限制使得它可以被 torch.compile 和 torch.export 捕获。
如果cond的参数满足约束条件,那么cond等同于以下内容:
def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands)
- 参数
-
-
pred (Union[bool, torch.Tensor]) – 一个布尔表达式或包含单个元素的张量,用于指示要应用哪个分支函数。
-
true_fn (Callable) – 一个在被追踪作用域内的可调用函数(a -> b)。
-
false_fn (Callable) – 在被追踪的作用域内的可调用函数(a -> b)。真分支和假分支的输入和输出必须保持一致,即输入参数必须相同,并且输出类型和形状也必须相同。
-
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 输入元组,包含传递给 true/false 函数的嵌套字典、列表或张量。
-
示例:
def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- 限制:
-
-
条件语句(即 pred)必须满足以下其中一个要求:
-
它是一个只包含一个元素的torch.Tensor,数据类型为
torch.bool
-
它是一个布尔表达式,例如 x.shape[0] > 10 或者 x.dim() > 1 and x.shape[1] > 10
-
-
分支函数(即true_fn和false_fn)必须满足以下所有条件:
-
函数签名必须与运算符匹配。
-
该函数必须返回一个与原张量具有相同属性(如形状和数据类型)的新张量。
-
该函数不能对输入或全局变量进行就地修改。(注意:在分支中可以使用add_等就地张量操作来处理中间结果)
-
-
警告
时间限制:
-
分支的输出必须是单一的Tensor。未来将会支持张量的Pytree。