在 C++ 中加载 TorchScript 模型
TorchScript 已不再积极开发。
顾名思义,PyTorch 的主要接口是 Python 编程语言。虽然 Python 在许多需要动态性和迭代便利性的场景中是一种合适且首选的语言,但同样存在许多情况下 Python 的这些特性并不理想。其中一个适用后者的情况是生产环境——这里是低延迟和严格部署要求的领域。在生产场景中,C++ 通常是首选语言,即使只是将其绑定到另一种语言(如 Java、Rust 或 Go)中。以下段落将概述 PyTorch 提供的方法,帮助您将现有的 Python 模型转换为可以在纯 C++ 环境中加载和执行的序列化表示,而无需依赖 Python。
第一步:将您的 PyTorch 模型转换为 Torch Script
PyTorch 模型从 Python 到 C++ 的旅程是通过 Torch Script 实现的,Torch Script 是一种 PyTorch 模型的表示形式,能够被 Torch Script 编译器理解、编译和序列化。如果您从现有的使用原生“eager” API 编写的 PyTorch 模型开始,首先需要将模型转换为 Torch Script。在大多数情况下,如下文所述,这只需要很少的工作量。如果您已经拥有一个 Torch Script 模块,可以跳过本教程的下一部分。
将 PyTorch 模型转换为 Torch Script 有两种方法。第一种称为 追踪(tracing),这是一种通过使用示例输入评估模型一次并记录这些输入在模型中的流动来捕获模型结构的机制。这种方法适用于较少使用控制流的模型。第二种方法是在模型中添加显式注解,告知 Torch Script 编译器它可以直接解析和编译您的模型代码,但需遵循 Torch Script 语言的约束。
您可以在官方 Torch Script 参考 中找到这两种方法的完整文档,以及关于使用哪种方法的进一步指导。
通过追踪转换为 Torch Script
要通过跟踪将 PyTorch 模型转换为 Torch Script,您需要将模型实例以及示例输入传递给 torch.jit.trace
函数。这将生成一个 torch.jit.ScriptModule
对象,其中嵌入了模型评估的跟踪信息,存储在模块的 forward
方法中:
importtorch
importtorchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
追踪后的 ScriptModule
现在可以与常规的 PyTorch 模块一样进行评估:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
通过注解转换为 Torch Script
在某些情况下,例如当您的模型采用特定形式的控制流时,您可能希望直接使用 Torch Script 编写模型,并相应地注释您的模型。例如,假设您有以下普通的 PyTorch 模型:
importtorch
classMyModule(torch.nn.Module):
def__init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
defforward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
由于此模块的 forward
方法使用了依赖于输入的控制流,因此不适用于跟踪。相反,我们可以将其转换为 ScriptModule
。为了将模块转换为 ScriptModule
,需要使用 torch.jit.script
编译该模块,如下所示:
classMyModule(torch.nn.Module):
def__init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
defforward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
如果您需要在 nn.Module
中排除某些方法,因为这些方法使用了 TorchScript 尚未支持的 Python 特性,您可以使用 @torch.jit.ignore
注解来标记这些方法。
sm
是一个已经准备好进行序列化的 ScriptModule
实例。
步骤 2:将您的脚本模块序列化到文件中
当您获得一个 ScriptModule
后,无论是通过追踪还是注解 PyTorch 模型,您都可以将其序列化并保存到文件中。之后,您可以在 C++ 中从这个文件加载该模块并执行,而无需依赖 Python。假设我们想序列化之前在追踪示例中展示的 ResNet18
模型。要进行此序列化操作,只需在模块上调用 save 并传入文件名即可:
traced_script_module.save("traced_resnet_model.pt")
这将在您的工作目录中生成一个 traced_resnet_model.pt
文件。如果您还想序列化 sm
,可以调用 sm.save("my_module_model.pt")
。至此,我们已经正式离开了 Python 的领域,准备进入 C++ 的世界。
第三步:在 C++ 中加载您的脚本模块
要在 C++ 中加载您序列化的 PyTorch 模型,您的应用程序必须依赖 PyTorch C++ API —— 也称为 LibTorch。LibTorch 发行版包含一组共享库、头文件和 CMake 构建配置文件。虽然 CMake 并不是依赖 LibTorch 的必需条件,但它是推荐的方法,并且在未来会得到良好的支持。在本教程中,我们将使用 CMake 和 LibTorch 构建一个简单的 C++ 应用程序,该应用程序仅加载并执行一个序列化的 PyTorch 模型。
一个最小的 C++ 应用程序
让我们从讨论加载模块的代码开始。以下代码已经可以实现:
#include<torch/script.h> // One-stop header.
#include<iostream>
#include<memory>
intmain(intargc,constchar*argv[]){
if(argc!=2){
std::cerr<<"usage: example-app <path-to-exported-script-module>\n";
return-1;
}
torch::jit::script::Modulemodule;
try{
// Deserialize the ScriptModule from a file using torch::jit::load().
module=torch::jit::load(argv[1]);
}
catch(constc10::Error&e){
std::cerr<<"error loading the model\n";
return-1;
}
std::cout<<"ok\n";
}
<torch/script.h>
头文件包含了运行示例所需的 LibTorch 库中的所有相关引用。我们的应用程序接受一个序列化的 PyTorch ScriptModule
文件的路径作为唯一的命令行参数,然后使用 torch::jit::load()
函数对该模块进行反序列化,该函数以该文件路径作为输入。作为返回,我们会得到一个 torch::jit::script::Module
对象。稍后我们将探讨如何执行它。
依赖 LibTorch 并构建应用程序
假设我们将上述代码存储在一个名为 example-app.cpp
的文件中。构建它的最小 CMakeLists.txt
文件可以像下面这样简单:
cmake_minimum_required(VERSION3.0FATAL_ERROR)
project(custom_ops)
find_package(TorchREQUIRED)
add_executable(example-appexample-app.cpp)
target_link_libraries(example-app"${TORCH_LIBRARIES}")
set_property(TARGETexample-appPROPERTYCXX_STANDARD17)
构建示例应用程序所需的最后一步是获取 LibTorch 发行版。您可以从 PyTorch 网站的下载页面获取最新的稳定版本。如果您下载并解压最新的压缩包,应该会得到一个包含以下目录结构的文件夹:
libtorch/
bin/
include/
lib/
share/
-
lib/
文件夹包含了您必须链接的共享库, -
include/
文件夹包含了您的程序需要包含的头文件, -
share/
文件夹包含了必要的 CMake 配置,以便启用上面提到的简单find_package(Torch)
命令。
在 Windows 上,调试版本和发布版本的 ABI 不兼容。如果您计划在调试模式下构建项目,请尝试使用 LibTorch 的调试版本。同时,请确保在下面的
cmake --build .
行中指定正确的配置。
最后一步是构建应用程序。为此,假设我们的示例目录结构如下:
example-app/
CMakeLists.txt
example-app.cpp
现在我们可以运行以下命令,在 example-app/
文件夹内构建应用程序:
mkdirbuild
cdbuild
cmake-DCMAKE_PREFIX_PATH=/path/to/libtorch..
cmake--build.--configRelease
其中 /path/to/libtorch
应为解压后的 LibTorch 分发包的完整路径。如果一切顺利,它应该看起来像这样:
root@4b5a67132e81:/example-app#mkdirbuild
root@4b5a67132e81:/example-app#cdbuild
root@4b5a67132e81:/example-app/build#cmake-DCMAKE_PREFIX_PATH=/path/to/libtorch..
*-TheCcompileridentificationisGNU5.4.0
*-TheCXXcompileridentificationisGNU5.4.0
*-CheckforworkingCcompiler:/usr/bin/cc
*-CheckforworkingCcompiler:/usr/bin/cc--works
*-DetectingCcompilerABIinfo
*-DetectingCcompilerABIinfo-done
*-DetectingCcompilefeatures
*-DetectingCcompilefeatures-done
*-CheckforworkingCXXcompiler:/usr/bin/c++
*-CheckforworkingCXXcompiler:/usr/bin/c++--works
*-DetectingCXXcompilerABIinfo
*-DetectingCXXcompilerABIinfo-done
*-DetectingCXXcompilefeatures
*-DetectingCXXcompilefeatures-done
*-Lookingforpthread.h
*-Lookingforpthread.h-found
*-Lookingforpthread_create
*-Lookingforpthread_create-notfound
*-Lookingforpthread_createinpthreads
*-Lookingforpthread_createinpthreads-notfound
*-Lookingforpthread_createinpthread
*-Lookingforpthread_createinpthread-found
*-FoundThreads:TRUE
*-Configuringdone
*-Generatingdone
*-Buildfileshavebeenwrittento:/example-app/build
root@4b5a67132e81:/example-app/build#make
Scanningdependenciesoftargetexample-app
[50%]BuildingCXXobjectCMakeFiles/example-app.dir/example-app.cpp.o
[100%]LinkingCXXexecutableexample-app
[100%]Builttargetexample-app
如果我们提供之前创建的已跟踪的 ResNet18
模型 traced_resnet_model.pt
的路径给生成的 example-app
可执行文件,我们应该会得到一个友好的“ok”响应。请注意,如果尝试使用 my_module_model.pt
运行此示例,您将收到一个错误提示,指出您的输入形状不兼容。my_module_model.pt
期望的是 1D 而不是 4D 的输入。
root@4b5a67132e81:/example-app/build#./example-app<path_to_model>/traced_resnet_model.pt
ok
步骤 4: 在 C++ 中执行脚本模块
成功地在 C++ 中加载了我们序列化的 ResNet18
模型,现在只需几行代码就可以执行它!让我们将这些代码添加到 C++ 应用程序的 main()
函数中:
// Create a vector of inputs.
std::vector<torch::jit::IValue>inputs;
inputs.push_back(torch::ones({1,3,224,224}));
// Execute the model and turn its output into a tensor.
at::Tensoroutput=module.forward(inputs).toTensor();
std::cout<<output.slice(/*dim=*/1,/*start=*/0,/*end=*/5)<<'\n';
前两行设置了模型的输入。我们创建了一个 torch::jit::IValue
的向量(script::Module
方法接受和返回的类型擦除值类型),并添加了一个输入。为了创建输入张量,我们使用了 torch::ones()
,这相当于 C++ API 中的 torch.ones
。然后我们运行 script::Module
的 forward
方法,将我们创建的输入向量传递给它。作为返回,我们获得了一个新的 IValue
,通过调用 toTensor()
将其转换为张量。
要了解更多关于
torch::ones
等函数以及 PyTorch C++ API 的信息,请参阅其官方文档:https://pytorch.org/cppdocs。PyTorch C++ API 提供了与 Python API 几乎一致的功能,使您可以像在 Python 中一样进一步操作和处理张量。
在最后一行,我们打印了输出的前五个条目。由于我们在本教程之前已经在 Python 中向模型提供了相同的输入,理想情况下我们应该看到相同的输出。让我们通过重新编译应用程序并使用相同的序列化模型来运行它,以验证这一点:
root@4b5a67132e81:/example-app/build#make
Scanningdependenciesoftargetexample-app
[50%]BuildingCXXobjectCMakeFiles/example-app.dir/example-app.cpp.o
[100%]LinkingCXXexecutableexample-app
[100%]Builttargetexample-app
root@4b5a67132e81:/example-app/build#./example-apptraced_resnet_model.pt
*0.2698-0.03810.4023-0.3010-0.0448
[Variable[CPUFloatType]{1,5}]
作为参考,之前在 Python 中的输出是:
tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
看起来匹配得很好!
要将模型移动到 GPU 内存中,您可以编写
model.to(at::kCUDA);
。确保模型的输入也位于 CUDA 内存中,可以通过调用tensor.to(at::kCUDA)
来实现,该调用将返回一个位于 CUDA 内存中的新张量。
第5步:获取帮助并探索API
本教程希望为您提供了关于 PyTorch 模型从 Python 到 C++ 路径的总体理解。通过本教程中描述的概念,您应该能够从一个普通的、“急切”模式的 PyTorch 模型,逐步转换到 Python 中编译的 ScriptModule
,再保存为磁盘上的序列化文件,最后在 C++ 中加载为可执行的 script::Module
,从而完成整个流程。
当然,还有许多概念我们未涵盖。例如,您可能会希望使用 C++ 或 CUDA 实现的自定义操作符扩展您的 ScriptModule
,并在纯 C++ 生产环境中加载的 ScriptModule
中执行该自定义操作符。好消息是:这是可行的,并且得到了良好的支持!目前,您可以探索 此 文件夹中的示例,我们很快将发布相关教程。在此期间,以下链接可能会对您有所帮助:
-
Torch Script 参考文档: https://pytorch.org/docs/master/jit.html
-
PyTorch C++ API 文档: https://pytorch.org/cppdocs/
-
PyTorch Python API 文档: https://pytorch.org/docs/
一如既往,如果您遇到任何问题或有疑问,可以使用我们的 论坛 或 GitHub issues 与我们联系。