torch.autograd.Function.vmap
- 静态Function.vmap(info, in_dims, *args)[源代码]
-
在
torch.vmap()
中定义此 autograd.Function 的行为。为了使
torch.autograd.Function()
支持torch.vmap()
,你必须重写这个静态方法或将generate_vmap_rule
设置为True
(但不能同时进行)。如果你选择覆盖这个静态方法:它必须接受
-
将一个
info
对象作为第一个参数。其中,info.batch_size
指定了被vmapped的维度大小,而info.randomness
是传递给torch.vmap()
的随机性选项。 -
作为第二个参数传递一个
in_dims
元组。对于args
中的每个参数,in_dims
包含一个对应的Optional[int]
值。如果参数不是张量或不被vmapped处理,则该值为None
;否则,它是一个整数,表示正在被vmapped的张量的维度。 -
*args
,这与传递给forward()
方法的参数相同。
vmap 静态方法的返回值是一个元组
(output, out_dims)
。类似于in_dims
,out_dims
应该与output
具有相同的结构,并且每个输出都包含一个out_dim
,以指定该输出是否具有 vmap 维度及其索引。请参见使用 autograd.Function 扩展 torch.func以获取更多详细信息。
-