PyTorch 入门指南
学习 PyTorch
图像和视频
音频
后端
强化学习
在生产环境中部署 PyTorch 模型
Profiling PyTorch
代码变换与FX
前端API
扩展 PyTorch
模型优化
并行和分布式训练
边缘端的 ExecuTorch
推荐系统
多模态

使用 Flask 通过 REST API 在 Python 中部署 PyTorch

作者: Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署一个 PyTorch 模型,并暴露一个用于模型推理的 REST API。具体来说,我们将部署一个预训练的 DenseNet 121 模型,用于检测图像。

此处使用的所有代码均以 MIT 许可证发布,并可在 Github 上获取。

这是关于在生产环境中部署 PyTorch 模型系列教程的第一部分。使用 Flask 是开始提供 PyTorch 模型服务的最简单方式,但对于需要高性能的场景,这种方式并不适用。针对这种情况:

API 定义

我们将首先定义我们的API端点、请求和响应类型。我们的API端点将位于/predict,它接收HTTP POST请求,请求中包含一个file参数,该参数包含图像。响应将是包含预测结果的JSON响应:

{"class_id":"n02124075","class_name":"Egyptian_cat"}

依赖项

通过运行以下命令安装所需的依赖项:

pipinstallFlask==2.0.1torchvision==0.10.0

简单 Web 服务器

以下是一个简单的 Web 服务器示例,摘自 Flask 的文档

fromflaskimport Flask
app = Flask(__name__)


@app.route('/')
defhello():
    return 'Hello World!'

我们还将更改响应类型,使其返回包含ImageNet类别ID和名称的JSON响应。更新后的app.py文件现在如下:

fromflaskimport Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
defpredict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在接下来的章节中,我们将重点编写推理代码。这包括两个部分:首先,我们将准备图像,使其能够输入到 DenseNet 中;其次,我们将编写代码以从模型中获取实际的预测结果。

图像预处理

DenseNet 模型要求输入图像为 224 x 224 大小的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化处理。您可以在此处了解更多信息:这里

我们将使用 torchvision 库中的 transforms 并构建一个转换流水线,以便按要求转换我们的图像。您可以在此处了解更多关于转换的信息:这里

importio

importtorchvision.transformsastransforms
fromPILimport Image

deftransform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法接收字节形式的图像数据,应用一系列变换后返回一个张量。要测试该方法,请以字节模式读取图像文件(首先将 ../_static/img/sample_file.jpeg 替换为计算机上文件的实际路径),并查看是否返回了一个张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

预测

现在我们将使用一个预训练的 DenseNet 121 模型来预测图像类别。我们将使用 torchvision 库中的一个模型,加载模型并进行推理。虽然在这个示例中我们使用了一个预训练模型,但您可以将同样的方法应用于您自己的模型。有关加载模型的更多信息,请参阅此教程

fromtorchvisionimport models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


defget_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量 y_hat 将包含预测类别的索引。然而,我们需要一个人类可读的类别名称。为此,我们需要一个从类别 ID 到名称的映射。请下载 此文件 并保存为 imagenet_class_index.json,记住保存位置(或者,如果您严格按照本教程的步骤操作,请将其保存在 tutorials/_static 目录下)。该文件包含了 ImageNet 类别 ID 到 ImageNet 类别名称的映射。我们将加载此 JSON 文件并获取预测索引对应的类别名称。

importjson

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

defget_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用 imagenet_class_index 字典之前,我们首先需要将张量值转换为字符串值,因为 imagenet_class_index 字典的键是字符串。我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

您应该会收到类似如下的响应:

['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类别 ID,第二项是人类可读的名称。

将模型集成到我们的 API 服务器中

在这最后一部分,我们将把我们的模型添加到 Flask API 服务器中。由于我们的 API 服务器需要接收图像文件,我们将更新 predict 方法以从请求中读取文件:

fromflaskimport request

@app.route('/predict', methods=['POST'])
defpredict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
importio
importjson

fromtorchvisionimport models
importtorchvision.transformsastransforms
fromPILimport Image
fromflaskimport Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


deftransform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


defget_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
defpredict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
FLASK_ENV=developmentFLASK_APP=app.pyflaskrun

用于向我们的应用程序发送 POST 请求的库:

importrequests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印 resp.json() 现在将显示以下内容:

{"class_id":"n02124075","class_name":"Egyptian_cat"}

我们编写的服务器非常简单,可能无法满足生产应用程序的所有需求。因此,以下是您可以采取的一些改进措施:

  • 端点 /predict 假设请求中始终会包含一个图像文件。这可能并不适用于所有请求。用户可能会使用不同的参数发送图像,或者根本不发送图像。

  • 用户也可能发送非图像类型的文件。由于我们没有处理错误,这会导致服务器崩溃。添加一个明确的错误处理路径,抛出异常,可以让我们更好地处理无效输入。

  • 尽管模型能够识别大量图像类别,但它可能无法识别所有图像。增强实现以处理模型无法识别图像内容的情况。

  • 我们在开发模式下运行 Flask 服务器,这并不适合在生产环境中部署。您可以查看此教程,了解如何在生产环境中部署 Flask 服务器。

  • 您还可以通过创建一个带有表单的页面来添加用户界面,该表单可以接收图像并显示预测结果。查看类似项目的演示及其源代码

  • 在本教程中,我们仅展示了如何构建一个能够一次返回单个图像预测的服务。我们可以修改服务,使其能够同时返回多个图像的预测结果。此外,service-streamer 库会自动将请求排队并采样为可以输入模型的小批量数据。您可以查看此教程

  • 最后,我们鼓励您查看页面顶部链接的其他关于部署 PyTorch 模型的教程。

本页目录