PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

模型集成

本教程演示了如何使用 torch.vmap 对模型集成进行向量化处理。

什么是模型集成?

模型集成将多个模型的预测结果结合在一起。传统方法是分别对某些输入运行每个模型,然后将预测结果进行组合。然而,如果您运行的是具有相同架构的模型,则可以使用 torch.vmap 将它们结合在一起。vmap 是一个函数变换,它可以将函数映射到输入张量的维度上。它的一个用例是通过向量化消除循环并加速它们。

让我们通过一个简单的 MLP 集成来演示如何实现这一点。

本教程需要 PyTorch 2.0.0 或更高版本。

importtorch
importtorch.nnasnn
importtorch.nn.functionalasF
torch.manual_seed(0)

# Here's a simple MLP
classSimpleMLP(nn.Module):
    def__init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    defforward(self, x):
        x = x.flatten(1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

让我们生成一批虚拟数据,并假设我们正在处理 MNIST 数据集。因此,虚拟图像的尺寸为 28x28,我们有一个大小为 64 的小批次数据。此外,假设我们想要结合来自 10 个不同模型的预测结果。

device = 'cuda'
num_models = 10

data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)

models = [SimpleMLP().to(device) for _ in range(num_models)]

我们有几个生成预测的选项。也许我们想为每个模型提供不同的随机小批量数据。或者,我们可能希望将相同的小批量数据通过每个模型运行(例如,如果我们正在测试不同模型初始化的效果)。

选项1:每个模型使用不同的小批量数据

minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]

选项 2:相同的 mini-batch

minibatch = data[0]
predictions2 = [model(minibatch) for model in models]

使用 vmap 向量化集成模型

让我们使用 vmap 来加速 for 循环。我们首先需要准备好模型以便与 vmap 一起使用。

首先,让我们通过堆叠每个参数将模型的状态组合在一起。例如,model[i].fc1.weight 的形状是 [784, 128];我们将堆叠这 10 个模型中的每个 .fc1.weight,以生成一个形状为 [10, 784, 128] 的更大的权重。

PyTorch 提供了 torch.func.stack_module_state 这个便捷函数来完成这个操作。

fromtorch.funcimport stack_module_state

params, buffers = stack_module_state(models)

接下来,我们需要定义一个供 vmap 使用的函数。该函数应接收参数、缓冲区和输入,并使用这些参数、缓冲区和输入来运行模型。我们将使用 torch.func.functional_call 来辅助实现这一功能:

fromtorch.funcimport functional_call
importcopy

# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')

deffmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

选项1:使用不同的minibatch为每个模型获取预测。

默认情况下,vmap会将函数映射到传入函数的所有输入的第一个维度上。在使用stack_module_state之后,每个params和buffers在前面都会有一个大小为‘num_models’的额外维度,而minibatches也会有一个大小为‘num_models’的维度。

print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension

assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'

fromtorchimport vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)

# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]

选项2:使用相同的小批量数据进行预测。

vmap 有一个 in_dims 参数,用于指定要映射的维度。通过使用 None,我们告诉 vmap 我们希望相同的小批量数据应用于所有 10 个模型。

predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)

assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)

简要说明:vmap 在转换函数类型方面存在一些限制。最适合转换的函数是纯函数:即输出仅由输入决定且没有副作用(例如数据突变)的函数。vmap 无法处理任意 Python 数据结构的突变,但它能够处理许多 PyTorch 的原位操作。

性能

对性能数据感到好奇吗?以下是具体的数据表现。

fromtorch.utils.benchmarkimport Timer
without_vmap = Timer(
    stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
    globals=globals())
with_vmap = Timer(
    stmt="vmap(fmodel)(params, buffers, minibatches)",
    globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fa06fa16530>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
  2.44 ms
  1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7fa06fa168f0>
vmap(fmodel)(params, buffers, minibatches)
  823.44 us
  1 measurement, 100 runs , 1 thread

使用 vmap 可以显著加速!

一般来说,使用 vmap 进行向量化应该比在 for 循环中运行函数更快,并且与手动批处理的性能相当。不过也有一些例外情况,例如如果我们尚未为某个特定操作实现 vmap 规则,或者底层内核未针对较旧的硬件(如 GPU)进行优化。如果您遇到这些情况,请通过在 GitHub 上提交问题告知我们。

本页目录