torch.func.grad_and_value
- torch.func.grad_and_value(func, argnums=0, has_aux=False)
-
返回一个函数,用于计算梯度和原生正向计算的结果元组。
- 参数
- 返回值
-
用于计算相对于其输入和正向计算的梯度元组的函数。默认情况下,该函数返回第一个参数的梯度张量元组及原始计算结果。如果指定
has_aux
为True
,则返回一个包含梯度元组和带有辅助对象的正向计算元组。若argnums
是整数元组,则返回一个包含相对于每个argnums
值的输出梯度元组及正向计算结果。 - 返回类型
查看
grad()
的示例