torch.autograd.Function.backward

static Function.backward(ctx, *grad_outputs)

定义一个公式,用于通过反向模式自动微分来区分操作。

此函数需要在所有子类中被重写。(定义此函数相当于定义了vjp函数。)

它必须接受一个上下文 ctx 作为第一个参数,然后是与 forward() 返回的输出数量相同的参数(前向函数中非张量输出将传递 None),并且应该返回与 forward() 输入数量相等的张量。每个参数是相对于给定输出的梯度,每个返回值应该是相对于相应输入的梯度。如果一个输入不是 Tensor 或者是一个不需要计算梯度的 Tensor,则可以为该输入传递 None 作为梯度。

可以使用上下文来检索在前向传递中保存的张量。它还有一个属性ctx.needs_input_grad,这是一个布尔元组,表示每个输入是否需要梯度计算。例如,如果backward() 需要计算forward() 的第一个输入相对于输出的梯度,则ctx.needs_input_grad[0] = True

返回类型

Any

本页目录