torch.autograd.graph.increment_version
- torch.autograd.graph.increment_version(tensor)[源代码]
-
更新自动微分元数据,以跟踪给定张量是否被就地修改。
这有助于在自动微分引擎中进行更准确的错误检查。PyTorch函数和在自定义Function中适当调用mark_dirty()时会自动完成此操作,因此你只需要在执行PyTorch不知道的就地操作(inplace operation)时显式调用此方法。例如,一个自定义内核读取Tensor的数据指针,并根据该指针就地修改内存。此函数可以接受单个张量或张量列表。
注意,对于单个就地操作来说,多次增加版本计数器是没有问题的。
注意,如果你传入的是在 torch.inference_mode() 下创建的张量,我们不会增加其版本计数器(因为该张量没有版本计数器)。