torch.autograd.Function.forward
- static Function.forward(*args, **kwargs)
-
定义自定义自动微分函数的前向传播。
此函数需要被所有子类重写。有兩種方法可以定義前向傳播:
用法 1(结合前向和 ctx):
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
-
它必须将上下文 ctx 作为第一个参数,并且可以跟随任意数量的其他参数(如张量或其它类型)。
用法 2(分开 forward 和 ctx):
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
-
“前向”函数不再接收ctx参数。
-
相反,你还必须重写
torch.autograd.Function.setup_context()
静态方法来设置ctx
对象。其中,output
是前向传播的输出,而inputs
是一个包含前向传播输入的元组。 -
更多细节请参见扩展torch.autograd
上下文可以用来存储任意数据,并在反向传递期间检索这些数据。张量不应直接存储在ctx上(尽管当前还没有强制执行以保持后向兼容性)。相反,如果张量用于
backward
(等同于vjp
),则应使用ctx.save_for_backward()
进行保存;如果张量用于jvp
,则应使用ctx.save_for_forward()
进行保存。- 返回类型
-