torch.func.grad_and_value

torch.func.grad_and_value(func, argnums=0, has_aux=False)

返回一个函数,用于计算梯度和原生正向计算的结果元组。

参数
  • func (Callable) – 一个接受一个或多个参数的 Python 函数,必须返回单个元素的张量。如果指定的 has_aux 等于 True,函数可以返回一个包含输出张量和其他辅助对象的元组:(output, aux)

  • argnums (intTuple[int]) – 指定用于计算梯度的参数。可以是单个整数或整数元组形式的 argnums。默认值:0。

  • has_aux (bool) – 标志,表示func 返回一个张量和其他辅助对象: (output, aux)。默认值为 False。

返回值

用于计算相对于其输入和正向计算的梯度元组的函数。默认情况下,该函数返回第一个参数的梯度张量元组及原始计算结果。如果指定 has_auxTrue,则返回一个包含梯度元组和带有辅助对象的正向计算元组。若 argnums 是整数元组,则返回一个包含相对于每个 argnums 值的输出梯度元组及正向计算结果。

返回类型

Callable

查看grad() 的示例

本页目录