PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

逐样本梯度

什么是它?

逐样本梯度计算是指计算一批数据中每个样本的梯度。它在差分隐私、元学习和优化研究中是一个有用的量。

本教程需要 PyTorch 2.0.0 或更高版本。

importtorch
importtorch.nnasnn
importtorch.nn.functionalasF
torch.manual_seed(0)

# Here's a simple CNN and loss function:

classSimpleCNN(nn.Module):
    def__init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    defforward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

defloss_fn(predictions, targets):
    return F.nll_loss(predictions, targets)

让我们生成一批虚拟数据,并假设我们正在处理一个MNIST数据集。这些虚拟图像的尺寸为28×28,我们使用大小为64的小批量。

device = 'cuda'

num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)

targets = torch.randint(10, (64,), device=device)

在常规的模型训练中,我们会将小批量数据通过模型进行前向传播,然后调用 .backward() 来计算梯度。这将生成整个小批量数据的“平均”梯度:

model = SimpleCNN().to(device=device)
predictions = model(data)  # move the entire mini-batch through the model

loss = loss_fn(predictions, targets)
loss.backward()  # back propagate the 'average' gradient of this mini-batch

与上述方法相反,逐样本梯度计算等同于:

  • 对数据的每个单独样本执行前向传播和反向传播,以获取单个(每个样本的)梯度。
defcompute_grad(sample, target):
    sample = sample.unsqueeze(0)  # prepend batch dimension for processing
    target = target.unsqueeze(0)

    prediction = model(sample)
    loss = loss_fn(prediction, target)

    return torch.autograd.grad(loss, list(model.parameters()))


defcompute_sample_grads(data, targets):
""" manually process each sample with per sample gradient """
    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
    sample_grads = zip(*sample_grads)
    sample_grads = [torch.stack(shards) for shards in sample_grads]
    return sample_grads

per_sample_grads = compute_sample_grads(data, targets)

sample_grads[0]model.conv1.weight 的逐样本梯度。model.conv1.weight.shape[32, 1, 3, 3];注意批次中每个样本都有一个梯度,总共有 64 个。

print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])

使用函数变换实现按样本梯度计算的高效方法

我们可以通过使用函数变换来高效地计算每个样本的梯度。

torch.func 函数变换 API 可以变换函数。我们的策略是定义一个计算损失值的函数,然后应用变换来构造一个计算每个样本梯度的函数。

我们将使用 torch.func.functional_call 函数将 nn.Module 视为一个函数。

首先,将 model 中的状态提取到两个字典中:参数和缓冲区。我们将分离它们,因为我们不会使用常规的 PyTorch 自动求导(例如 Tensor.backward()torch.autograd.grad)。

fromtorch.funcimport functional_call, vmap, grad

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

接下来,我们定义一个函数来计算模型在单个输入(而不是一批输入)情况下的损失。这个函数需要接收参数、输入和目标值,这一点很重要,因为我们将在它们之上进行变换。

注意 - 由于模型最初是设计用于处理批量的,因此我们将使用 torch.unsqueeze 来添加一个批次维度。

defcompute_loss(params, buffers, sample, target):
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)

    predictions = functional_call(model, (params, buffers), (batch,))
    loss = loss_fn(predictions, targets)
    return loss

现在,让我们使用grad变换来创建一个新函数,该函数计算相对于compute_loss第一个参数(即params)的梯度。

ft_compute_grad = grad(compute_loss)

ft_compute_grad 函数用于计算单个(样本,目标)对的梯度。我们可以使用 vmap 来使其计算整个批次的样本和目标的梯度。注意,in_dims=(None, None, 0, 0) 是因为我们希望将 ft_compute_grad 映射到数据和目标的第 0 维度,并对每个样本和目标使用相同的 params 和缓冲区。

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

最后,让我们使用转换后的函数来计算每个样本的梯度:

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

我们可以再次确认,使用 gradvmap 得到的结果与手动逐个处理的结果是否一致:

for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)

简要说明:vmap 在转换函数类型方面存在一些限制。最适合转换的函数是纯函数:即输出仅由输入决定,且没有副作用(例如数据修改)的函数。vmap 无法处理任意 Python 数据结构中的修改,但它能够处理许多 PyTorch 的原位操作。

性能对比

想了解 vmap 的性能表现如何吗?

目前在新款 GPU(如 A100,Ampere 架构)上取得了最佳结果,在这个示例中我们看到了高达 25 倍的加速。以下是我们构建机器上的一些结果:

defget_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
    second_res = second.times[0]
    first_res = first.times[0]

    gain = (first_res-second_res)/first_res
    if gain < 0: gain *=-1
    final_gain = gain*100

    print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")

fromtorch.utils.benchmarkimport Timer

without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)

print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')

get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7fb57a7820>
compute_sample_grads(data, targets)
  100.25 ms
  1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f7ef1fb7fa0>
ft_compute_sample_grad(params, buffers, data, targets)
  8.59 ms
  1 measurement, 100 runs , 1 thread
Performance delta: 1067.2884 percent improvement with vmap

在 PyTorch 中,还有其他经过优化的解决方案(例如在 https://github.com/pytorch/opacus 中)可以计算逐样本梯度,这些方案也比简单方法表现更好。但有趣的是,结合 vmapgrad 可以带来显著的加速效果。

一般来说,使用 vmap 进行向量化应该比在 for 循环中运行函数更快,并且与手动批处理相当。不过也有一些例外情况,例如如果我们尚未为特定操作实现 vmap 规则,或者底层内核未针对旧硬件(如 GPU)进行优化。如果您遇到这些情况,请在 GitHub 上提交 issue 告知我们。

本页目录