• Tutorials >
  • 通过Flask以REST API的形式部署PyTorch模型
Shortcuts

通过Flask以REST API的形式部署PyTorch模型

Created On: Jul 03, 2019 | Last Updated: Jun 05, 2025 | Last Verified: Nov 05, 2024

作者Avinash Sajjanshetty

在本教程中,我们将使用Flask部署一个PyTorch模型,并为模型推理公开一个REST API。具体来说,我们将部署一个预训练的DenseNet 121模型,该模型可用于检测图像。

小技巧

此处使用的所有代码均根据MIT许可证发布,并可在 Github 上找到。

这是有关在生产环境中部署PyTorch模型系列的第一个教程。以这种方式使用Flask是开始服务您的PyTorch模型的最简单方式,但对于有高性能需求的用例,此方法可能不适合。针对这类需求:

API 定义

我们将首先定义 API 端点以及请求和响应类型。我们的 API 端点将位于 /predict,它接收带有包含图像的 file 参数的 HTTP POST 请求。响应将是包含预测结果的 JSON 响应:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖

通过运行以下命令安装所需的依赖项:

pip install Flask==2.0.1 torchvision==0.10.0

简单的 Web 服务器

以下是从 Flask 官方文档中截取的一个简单 Web 服务器

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

我们还将更改响应类型,使其返回包含 ImageNet 类 ID 和名称的 JSON 响应。更新后的 app.py 文件如下:

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推断

在接下来的部分中,我们将重点编写推断代码。这将包括两个步骤:准备图像以便可以馈送到 DenseNet,然后编写代码从模型中获取实际预测。

准备图像

DenseNet 模型要求图像是大小为 224 x 224 的 3 通道 RGB 图像。我们还将使用所需的均值和标准差值对图像张量进行归一化。可以在 这里 阅读更多相关信息。

我们将使用来自 torchvision 库的 transforms,并构建一个转换管道,以按需转换我们的图像。您可以在 这里 阅读更多关于转换的信息。

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法接收字节形式的图像数据,应用一系列转换并返回一个张量。为了测试上述方法,以字节模式读取图像文件(首先将 ../_static/img/sample_file.jpeg 替换为您计算机上的实际文件路径),并查看是否可以成功返回张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

预测

现在我们将使用一个预训练的 DenseNet 121 模型来预测图像类别。我们将使用来自 torchvision 库的模型,加载模型并进行推断。尽管本示例中我们使用的是预训练的模型,但您可以对自己的模型使用相同的方式。有关加载您自己的模型的更多信息,请参阅此 教程

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量 y_hat 将包含预测类别 ID 的索引。然而,我们需要一个人类可读的类别名称。为此,我们需要从类别 ID 到名称的映射。下载 此文件 并保存为 imagenet_class_index.json,记住保存位置(如果您按照本教程的具体步骤操作,请将其保存到 tutorials/_static 文件夹中)。此文件包含 ImageNet 类别 ID 与名称的映射。我们将加载此 JSON 文件,并通过预测索引获取类别名称。

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用 imagenet_class_index 字典之前,我们首先需要将张量值转换为字符串值,因为 imagenet_class_index 字典中的键是字符串。我们将测试上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

您应该会得到如下响应:

['n02124075', 'Egyptian_cat']

数组中的第一个元素是 ImageNet 类别 ID,第二个元素是人类可读的名称。

将模型集成到我们的 API 服务器中

在最后一部分中,我们将把模型添加到我们的 Flask API 服务器中。由于我们的 API 服务器需要接收图像文件,我们将更新 predict 方法以从请求中读取文件:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
FLASK_ENV=development FLASK_APP=app.py flask run

使用库发送一个 POST 请求到我们的应用程序:

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

打印 resp.json() 现在会显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

我们编写的服务器非常简单,可能无法满足您的生产应用的所有需求。因此,以下是一些可以改进的地方:

  • 端点 /predict 假设请求中总是会包含一个图像文件。这种情况可能并不适用于所有请求。用户可能会发送包含不同参数的图像,或根本不发送图像。

  • 用户也可能会发送非图像类型的文件。由于我们没有处理错误,这可能会导致服务器崩溃。添加一个显式的错误处理路径,可以在出现错误输入时抛出异常,更好地进行处理。

  • 即使模型能够识别大量类别的图像,它也可能无法识别所有图像。可以增强实现,以处理模型无法从图像中识别任何内容的情况。

  • 我们以开发模式运行 Flask 服务器,这并不适合生产环境中的部署。您可以查看 此教程,了解如何在生产环境中部署 Flask 服务器。

  • 您还可以通过创建一个带有表单的页面为服务器添加用户界面,该表单接收图像并显示预测结果。

  • 在本教程中,我们仅展示如何构建一个服务,以一次返回单张图像的预测结果。您可以修改此服务,以一次返回多张图像的预测结果。此外,service-streamer 库会自动将请求排队并采样为小批量进行处理,以供模型使用。您可以查看 此教程

  • 最后,我们鼓励您查看页面顶部提供的其他 PyTorch 模型部署教程。

**脚本的总运行时间:**(0分钟0.000秒)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源