Shortcuts

使用 Flask 部署

Created On: May 04, 2020 | Last Updated: Sep 15, 2021 | Last Verified: Not Verified

在此教程中,您将学习:

  • 如何将训练后的 PyTorch 模型封装在 Flask 容器中并通过 Web API 进行暴露

  • 如何将 Web 请求解析为供模型使用的 PyTorch 张量

  • 如何将模型的输出封装为 HTTP 响应

要求

您需要一个安装了以下包(及其依赖项)的 Python 3 环境:

  • PyTorch 1.5

  • TorchVision 0.6.0

  • Flask 1.1

可选:若需获得一些支持性文件,您需要安装 git。

安装 PyTorch 和 TorchVision 的指南可在 pytorch.org 上获得。安装 Flask 的指南可在 Flask 站点 上获得。

什么是 Flask?

Flask 是一个用 Python 编写的轻量级 Web 服务器。它为您提供了一种方便的方法,能够快速设置 Web API,以从训练后的 PyTorch 模型进行预测,无论是直接使用还是作为较大系统中的 Web 服务。

设置和支持文件

我们将创建一个 Web 服务,它接收图像并将其映射到 ImageNet 数据集的 1000 个类别之一。为此,您需要一个用于测试的图像文件。可选地,您还可以获得一个文件,该文件可将模型输出的类别索引映射为人类可读的类别名称。

选项 1:快速获取两个文件

您可以通过检出 TorchServe 仓库并将它们复制到您的工作目录中快速获取两个支持文件。(注意:此教程与 TorchServe 没有依赖关系——这只是获取文件的快捷方式。)从您的终端提示符发出以下命令:

git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .

然后您就能获得它们!

选项 2:使用您自己的图像

Flask 服务中的 index_to_name.json 文件是可选项。您可以使用自己的图像测试服务——只需确保它是 3 色 JPEG。

构建您的 Flask 服务

此教程最后展示了 Flask 服务的完整 Python 脚本;您可以将其复制粘贴到自己的 app.py 文件中。下面我们将逐步查看个别部分,以明确它们的功能。

导入

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

依次为:

  • 我们将使用来自 torchvision.models 的预训练 DenseNet 模型

  • torchvision.transforms 包含用于操作图像数据的工具

  • Pillow(PIL)用于首次加载图像文件

  • 当然还需要 flask 的类

预处理

def transform_image(infile):
    input_transforms = [transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)
    timg = my_transforms(image)
    timg.unsqueeze_(0)
    return timg

Web 请求提供了一个图像文件,但我们模型需要一个形状为 (N, 3, 224, 224) 的 PyTorch 张量,其中 N 是输入批次的项数。(我们只会有一个批次大小为 1。)首先,我们创建了一组 TorchVision 变换,这些变换可重设图像尺寸并裁剪图像、将图像转换为张量,然后归一化张量中的值。(有关此归一化的更多信息,请参阅 torchvision.models_ 的文档。)

之后,我们打开该文件并应用这些变换。变换返回了一个形状为 (3, 224, 224) 的张量——224x224 图像的 3 个颜色通道。因为我们需要将这张单独的图像变成一个批次,所以我们使用 unsqueeze_(0) 来在原地修改张量,并添加一个新的第一个维度。张量包含相同的数据,但现在形状为 (1, 3, 224, 224)。

一般来说,即使您未处理图像数据,也需要将 HTTP 请求中的输入转换为 PyTorch 可以处理的张量。

推理

def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)
    _, y_hat = outputs.max(1)
    prediction = y_hat.item()
    return prediction

推理本身是最简单的部分:当我们将输入张量传递给模型时,它会返回一个值张量,这些值代表模型估算的图像属于特定类别的可能性。max() 调用找到可能性值最大的类别,并返回该值及其对应的 ImageNet 类索引。最后,我们使用 item() 调用从包含类索引的张量中提取该索引并返回。

后处理

def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name

render_prediction() 方法将预测的类索引映射为人类可读的类标签。在从模型获得预测后,通常会进行后处理,以使预测适合人类使用或适合由另一个软件使用。

运行完整的 Flask 应用程序

将以下内容粘贴到一个名为 app.py 的文件中:

import io
import json
import os

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
model = models.densenet121(pretrained=True)               # Trained on 1000 classes from ImageNet
model.eval()                                              # Turns off autograd



img_class_map = None
mapping_file_path = 'index_to_name.json'                  # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
    with open (mapping_file_path) as f:
        img_class_map = json.load(f)



# Transform input into the form our model expects
def transform_image(infile):
    input_transforms = [transforms.Resize(255),           # We use multiple TorchVision transforms to ready the image
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],       # Standard normalization for ImageNet model input
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)                            # Open the image file
    timg = my_transforms(image)                           # Transform PIL image to appropriately-shaped PyTorch tensor
    timg.unsqueeze_(0)                                    # PyTorch models expect batched input; create a batch of 1
    return timg


# Get a prediction
def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)                 # Get likelihoods for all ImageNet classes
    _, y_hat = outputs.max(1)                             # Extract the most likely class
    prediction = y_hat.item()                             # Extract the int value from the PyTorch tensor
    return prediction

# Make the prediction human-readable
def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name


@app.route('/', methods=['GET'])
def root():
    return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        if file is not None:
            input_tensor = transform_image(file)
            prediction_idx = get_prediction(input_tensor)
            class_id, class_name = render_prediction(prediction_idx)
            return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

从终端提示符启动服务器,请发出以下命令:

FLASK_APP=app.py flask run

默认情况下,您的 Flask 服务器在端口 5000 上监听。当服务器正在运行时,打开另一个终端窗口,并测试您的新推理服务器:

curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "[email protected]"

如果一切设置正确,您应该收到类似以下的响应:

{"class_id":285,"class_name":"Egyptian_cat"}

重要资源

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源