优化 Vision Transformer 模型以进行部署
Vision Transformer 模型将最前沿的基于注意力机制的 Transformer 模型(最初在自然语言处理领域引入,用于实现各种最先进的成果)应用于计算机视觉任务。Facebook 的高效图像 Transformer 模型 DeiT 是一个基于 ImageNet 训练的图像分类 Vision Transformer 模型。
在本教程中,我们将首先介绍 DeiT 是什么以及如何使用它,然后逐步讲解如何在 iOS 和 Android 应用中完成模型的脚本编写、量化、优化和使用。我们还将比较量化优化后的模型与未量化未优化模型的性能,并展示在模型应用过程中进行量化和优化所带来的好处。
什么是 DeiT
卷积神经网络(CNNs)自2012年深度学习兴起以来,一直是图像分类的主要模型,但CNNs通常需要数亿张图像进行训练才能达到最先进(SOTA)的结果。DeiT是一种视觉Transformer模型,它在执行图像分类任务时,仅需更少的数据和计算资源就能与领先的CNNs竞争,这得益于DeiT的两个关键组件:
- 数据增强,模拟在更大数据集上的训练;
- 原生蒸馏,使 Transformer 网络能够从 CNN 的输出中学习。
DeiT 展示了 Transformers 可以在数据和资源有限的情况下成功应用于计算机视觉任务。有关 DeiT 的更多详细信息,请参阅 仓库 和 论文。
使用 DeiT 进行图像分类
请查阅 DeiT 仓库中的 README.md
文件,了解如何使用 DeiT 进行图像分类的详细信息。若要进行快速测试,请先安装所需的依赖包:
pip install torch torchvision timm pandas requests
要在 Google Colab 中运行,请通过执行以下命令来安装依赖项:
!pip install timm pandas requests
然后运行以下脚本:
fromPILimport Image
importtorch
importtimm
importrequests
importtorchvision.transformsastransforms
fromtimm.data.constantsimport IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# should be 1.8.0
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
transform = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.6.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: FutureWarning:
Importing from timm.models.registry is deprecated, please import via timm.models
/usr/local/lib/python3.10/dist-packages/timm/models/layers/__init__.py:48: FutureWarning:
Importing from timm.models.layers is deprecated, please import via timm.layers
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:
Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:
Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:
Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:
Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:
Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:
Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:
Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:
Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
0%| | 0.00/330M [00:00<?, ?B/s]
5%|5 | 16.9M/330M [00:00<00:01, 176MB/s]
11%|#1 | 37.1M/330M [00:00<00:01, 197MB/s]
17%|#7 | 57.5M/330M [00:00<00:01, 205MB/s]
24%|##3 | 77.9M/330M [00:00<00:01, 208MB/s]
30%|##9 | 98.2M/330M [00:00<00:01, 210MB/s]
36%|###5 | 119M/330M [00:00<00:01, 211MB/s]
42%|####2 | 139M/330M [00:00<00:00, 212MB/s]
48%|####8 | 160M/330M [00:00<00:00, 212MB/s]
54%|#####4 | 180M/330M [00:00<00:00, 213MB/s]
61%|###### | 200M/330M [00:01<00:00, 213MB/s]
67%|######6 | 221M/330M [00:01<00:00, 213MB/s]
73%|#######3 | 242M/330M [00:01<00:00, 214MB/s]
79%|#######9 | 262M/330M [00:01<00:00, 214MB/s]
86%|########5 | 282M/330M [00:01<00:00, 214MB/s]
92%|#########1| 303M/330M [00:01<00:00, 214MB/s]
98%|#########7| 324M/330M [00:01<00:00, 214MB/s]
100%|##########| 330M/330M [00:01<00:00, 211MB/s]
269
输出结果应为 269,根据 ImageNet 的类别索引与 标签文件 的映射关系,这对应的是 灰狼、Canis lupus
。
既然我们已经验证了可以使用 DeiT 模型对图像进行分类,接下来我们看看如何修改模型,使其能够在 iOS 和 Android 应用程序中运行。
脚本化 DeiT
要在移动设备上使用该模型,我们首先需要对模型进行脚本化处理。有关快速概述,请参阅脚本化和优化实用技巧。运行以下代码,将上一步中使用的 DeiT 模型转换为可在移动设备上运行的 TorchScript 格式。
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main
脚本化模型文件 fbdeit_scripted.pt
的大小约为 346MB,已生成。
量化 DeiT
为了在保持推理精度基本不变的情况下显著减小训练后模型的大小,可以将量化技术应用于模型。由于 DeiT 中使用了 Transformer 模型,我们可以轻松地对模型应用动态量化,因为动态量化对 LSTM 和 Transformer 模型效果最佳(更多详情请参见这里)。
现在运行以下代码:
# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/usr/local/lib/python3.10/dist-packages/torch/ao/quantization/observer.py:229: UserWarning:
Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
这将生成脚本化和量化后的模型 fbdeit_quantized_scripted.pt
,其大小约为 89MB,相比未量化模型的大小 346MB 减少了 74%!
您可以使用 scripted_quantized_model
来生成相同的推理结果:
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269
优化 DeiT
在移动设备上使用量化和脚本化模型之前的最后一步是对其进行优化:
fromtorch.utils.mobile_optimizerimport optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
生成的 fbdeit_optimized_scripted_quantized.pt
文件大小与量化、脚本化但未优化的模型大致相同。推理结果保持不变。
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269
使用 Lite 解释器
为了了解 Lite Interpreter 在模型大小缩减和推理速度提升方面的效果,让我们创建该模型的精简版本。
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
尽管轻量级模型的体积与非轻量级版本相当,但在移动设备上运行轻量级版本时,推理速度预计会有所提升。
推理速度对比
要查看四种模型(原始模型、脚本化模型、量化并脚本化的模型、优化并量化并脚本化的模型)的推理速度有何不同,请运行以下代码:
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
out = ptl(img)
print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 100.73ms
scripted model: 112.45ms
scripted & quantized model: 120.77ms
scripted & quantized & optimized model: 120.85ms
lite model: 119.98ms
在 Google Colab 上运行的结果如下:
originalmodel:1236.69ms
scriptedmodel:1226.72ms
scripted&quantizedmodel:593.19ms
scripted&quantized&optimizedmodel:598.01ms
litemodel:600.72ms
以下结果总结了每个模型的推理时间以及每个模型相对于原始模型的百分比减少量。
importpandasaspd
importnumpyasnp
df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
columns=['Inference Time', 'Reduction'])], axis=1)
print(df)
"""
Model Inference Time Reduction
0 original model 1236.69ms 0%
1 scripted model 1226.72ms 0.81%
2 scripted & quantized model 593.19ms 52.03%
3 scripted & quantized & optimized model 598.01ms 51.64%
4 lite model 600.72ms 51.43%
"""
Model ... Reduction
0 original model ... 0%
1 scripted model ... -11.63%
2 scripted & quantized model ... -19.90%
3 scripted & quantized & optimized model ... -19.97%
4 lite model ... -19.11%
[5 rows x 3 columns]
'\n Model Inference Time Reduction\n0\toriginal model 1236.69ms 0%\n1\tscripted model 1226.72ms 0.81%\n2\tscripted & quantized model 593.19ms 52.03%\n3\tscripted & quantized & optimized model 598.01ms 51.64%\n4\tlite model 600.72ms 51.43%\n'