torch.nn.utils.stateless.functional_call

torch.nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs=None, *, tie_weights=True, strict=False)[源代码]

通过对模块的参数和缓冲区进行替换,使用提供的参数和缓冲区来进行功能调用。

警告

此 API 自 PyTorch 2.0 起已弃用,并将在未来的 PyTorch 版本中被移除。请使用torch.func.functional_call()作为替代,它是该 API 的直接替换方案。

注意

如果模块具有活动的参数化,可以在 parameters_and_buffers 参数中传递一个名称与常规参数名相同的值来完全禁用该参数化。如果你想将参数化函数应用于传递的值,请将键设置为{子模块名}.parametrizations.{参数名}.original

注意

如果模块对参数/缓冲区执行就地操作,这些变化将会在parameters_and_buffers中反映出来。

示例:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)

注意

如果模块具有绑定权重,则functional_call 是否遵守这些绑定取决于 tie_weights 标志。

示例:

>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
参数
  • module (torch.nn.Module) – 需要调用的模块

  • parameters_and_buffers (dict of str and Tensor) – 在模块调用时使用的参数。

  • args (Anytuple) – 传递给模块调用的参数。如果不是元组,则视为单个参数。

  • kwargs (dict) – 模块调用所需的关键词参数

  • tie_weights (bool, 可选) – 如果为 True,原始模型中绑定的参数和缓冲区在重新参数化版本中也将被视为绑定。因此,如果传递给这些绑定的不同值,则会引发错误。如果为 False,则除非为两个权重传递相同的值,否则不会尊重原始绑定的参数和缓冲区。默认值:True。

  • strict (bool, 可选) – 如果为 True,传入的参数和缓冲区必须与原始模块中的完全匹配。否则,如果有任何缺失或额外的关键字,则会引发错误。默认值:False。

返回值

调用 module 的结果。

返回类型

任何

本页目录