torch.cond

torch.cond(pred, true_fn, false_fn, operands)

根据条件选择应用true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一个原型功能,目前对输入和输出类型的支持有限,并且不支持训练。请期待未来版本的 PyTorch 中更稳定的实现。更多关于特性分类的信息,请参阅:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

cond 是一个结构化的控制流操作符。它类似于 Python 的 if 语句,但对 true_fnfalse_fnoperands 有一些限制,这些限制使得它可以被 torch.compile 和 torch.export 捕获。

如果cond的参数满足约束条件,那么cond等同于以下内容:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
参数
  • pred (Union[bool, torch.Tensor]) – 一个布尔表达式或包含单个元素的张量,用于指示要应用哪个分支函数。

  • true_fn (Callable) – 一个在被追踪作用域内的可调用函数(a -> b)。

  • false_fn (Callable) – 在被追踪的作用域内的可调用函数(a -> b)。真分支和假分支的输入和输出必须保持一致,即输入参数必须相同,并且输出类型和形状也必须相同。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 输入元组,传递给 true/false 函数。

示例:

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制:
  • 条件语句(即 pred)必须满足以下其中一个要求:

    • 它是一个只包含一个元素的torch.Tensor,数据类型为torch.bool

    • 它是一个布尔表达式,例如 x.shape[0] > 10 或者 x.dim() > 1 and x.shape[1] > 10

  • 分支函数(即true_fnfalse_fn)必须满足以下所有条件:

    • 函数签名必须与运算符匹配。

    • 该函数必须返回一个与原张量具有相同属性(如形状和数据类型)的新张量。

    • 该函数不能对输入或全局变量进行就地修改。(注意:在分支中可以使用add_等就地张量操作来处理中间结果)

警告

时间限制:

  • 分支的输出必须是单一的Tensor。未来将会支持张量的Pytree。

本页目录