torch.vmap

torch.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)

vmap 是一个向量化的映射函数;vmap(func) 返回一个新的函数,该新函数会将 func 应用于输入数据的一个维度。从语义上讲,vmap 会将映射操作推入由 func 调用的 PyTorch 操作中,从而实现对这些操作的向量化。

vmap 适用于处理批量维度:可以编写一个在单个示例上运行的函数 func,然后通过 vmap(func) 将其转换为可以在批量数据上运行的函数。当与自动微分结合使用时,vmap 还可用于计算批量梯度。

注意

torch.vmap()torch.func.vmap() 是别名关系,方便使用。请选择你喜欢的一个。

参数
  • func函数)– 一个接受一个或多个参数的Python函数,必须返回一个或多个张量。

  • in_dims (int嵌套结构) – 指定输入数据的哪个维度需要进行映射。in_dims 应该与输入数据具有相同的结构。如果某个特定输入的 in_dim 为 None,则表示该输入没有需要映射的维度。默认值:0。

  • out_dims (intTuple[int]) – 指定映射维度在输出中的位置。如果 out_dims 是一个元组,则它应为每个输出包含一个元素。默认值:0。

  • randomness (str) – 指定此 vmap 中的随机性在批次之间是否相同或不同。如果设置为 'different',每个批次将有不同的随机性;若设置为 'same',所有批次中的随机性保持一致;若设置为 'error',调用任何随机函数时将会引发错误。默认值:'error'。注意:此标志仅适用于 PyTorch 的随机操作,并不适用于 Python 的 random 模块或 numpy 随机性。

  • chunk_size (Noneint) – 如果为 None(默认值),则对输入应用单个 vmap。如果不为 None,则每次计算 chunk_size 个样本的 vmap。注意,chunk_size=1 等同于使用 for 循环来计算 vmap。如果你在计算 vmap 时遇到内存问题,请尝试设置非 None 的 chunk_size。

返回值

返回一个新的“批处理”函数。它接受与func相同的输入参数,但每个输入参数在in_dims指定的位置上增加了一个维度。它的输出与func相同,但在out_dims指定的位置上为每个输出增加了额外的维度。

返回类型

Callable

使用vmap()的一个例子是计算批量点积。PyTorch 没有提供批量 torch.dot API;不要在文档中徒劳地寻找,而是使用vmap()来构建一个新的函数。

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)

vmap() 可以帮助隐藏批次维度,从而使模型的创建更加简单。

>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
>>>     # Very simple linear model with activation
>>>     return feature_vec.dot(weights).relu()
>>>
>>> examples = torch.randn(batch_size, feature_size)
>>> result = torch.vmap(model)(examples)

vmap() 还可以帮助向量化之前难以批量处理或不可能批量处理的计算,例如高阶梯度计算。PyTorch 的自动微分引擎可以计算 vjps(向量-Jacobian 积)。对于某些函数 f: R^N -> R^N,通常需要调用 N 次 autograd.grad 来计算完整的 Jacobian 矩阵,每次调用对应于 Jacobian 的一行。使用 vmap(),我们可以向量化整个计算过程,在一次调用 autograd.grad 中完成 Jacobian 的计算。

>>> # Setup
>>> N = 5
>>> f = lambda x: x ** 2
>>> x = torch.randn(N, requires_grad=True)
>>> y = f(x)
>>> I_N = torch.eye(N)
>>>
>>> # Sequential approach
>>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
>>>                  for v in I_N.unbind()]
>>> jacobian = torch.stack(jacobian_rows)
>>>
>>> # vectorized gradient computation
>>> def get_vjp(v):
>>>     return torch.autograd.grad(y, x, v)
>>> jacobian = torch.vmap(get_vjp)(I_N)

vmap() 还可以嵌套使用,生成包含多个批处理维度的输出。

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
>>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
>>> batched_dot(x, y) # tensor of size [2, 3]

如果输入没有在第一维进行批量处理,in_dims 指定每个输入应沿哪一维度进行批量处理。

>>> torch.dot                            # [N], [N] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
>>> x, y = torch.randn(2, 5), torch.randn(2, 5)
>>> batched_dot(x, y)   # output is [5] instead of [2] if batched along the 0th dimension

如果有多个输入,并且每个输入在不同的维度上进行批处理,那么 in_dims 必须是一个元组,其中包含每个输入的批处理维度。

>>> torch.dot                            # [D], [D] -> []
>>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None

如果输入是 Python 结构体,in_dims 必须是一个包含与输入形状相匹配的结构体的元组:

>>> f = lambda dict: torch.dot(dict['x'], dict['y'])
>>> x, y = torch.randn(2, 5), torch.randn(5)
>>> input = {'x': x, 'y': y}
>>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
>>> batched_dot(input)

默认情况下,输出在第一维度上进行批量处理。然而,可以使用out_dims沿任何维度进行批量处理。

>>> f = lambda x: x ** 2
>>> x = torch.randn(2, 5)
>>> batched_pow = torch.vmap(f, out_dims=1)
>>> batched_pow(x) # [5, 2]

对于任何使用 kwargs 的函数,返回的函数不会批量处理 kwargs,但会接受 kwargs。

>>> x = torch.randn([2, 5])
>>> def fn(x, scale=4.):
>>>   return x * scale
>>>
>>> batched_pow = torch.vmap(fn)
>>> assert torch.allclose(batched_pow(x), x * 4)
>>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]

注意

vmap 默认不提供通用的自动批处理功能,也不支持处理可变长度的序列。

本页目录