Optimizing Vision Transformer Model for Deployment
Created On: Mar 15, 2021 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024
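Vision Transformer models bring the attention-based Transformer architecture, which was introduced in natural language processing and achieves state-of-the-art (SOTA) results on a wide range of tasks there, to computer vision tasks. Facebook's Data-efficient Image Transformer (DeiT) is a Vision Transformer model trained on ImageNet for image classification.
In this tutorial, we first cover what DeiT is and how to use it, then walk through the complete steps of scripting, quantizing, optimizing, and using the model in iOS and Android apps. We also compare the performance of the quantized, optimized models with the non-quantized, non-optimized ones, and show the benefits of applying quantization and optimization along the way.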
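What is DeiT
Convolutional neural networks (CNNs) have been the workhorse models for image classification since deep learning took off in 2012, but they typically need hundreds of millions of training images to reach SOTA results. DeiT is a Vision Transformer model that needs far less data and compute to train, yet competes with the leading CNNs on image classification. Two key components of DeiT make this possible:
- Data augmentation, which simulates training on a much larger dataset;
- Native distillation, which lets the Transformer network learn from a CNN's output.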
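DeiT shows that Transformers can be applied successfully to computer vision tasks even with limited access to data and resources. For more details on DeiT, see the `repo <https://github.com/facebookresearch/deit>`_ and the paper.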
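To make the distillation idea concrete, the sketch below computes a hard-label distillation loss for a single sample in plain Python: one output head of the student is scored against the true label, the other against the class the teacher predicts. This is a simplified illustration of the hard-distillation form described in the DeiT paper, not code from the DeiT repository; the function names, the equal 0.5 weights, and the toy logits are assumptions made for this example.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of softmax(logits) against an integer class label."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label):
    """Average of the true-label loss and the teacher-label loss.

    cls_logits:     student output supervised by the ground-truth label
    dist_logits:    student output supervised by the teacher's prediction
    teacher_logits: output of the (hypothetical) CNN teacher
    """
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return 0.5 * cross_entropy(cls_logits, label) + 0.5 * cross_entropy(dist_logits, teacher_label)


# Toy three-class example; all values are made up for illustration.
cls_logits = [2.0, 0.5, -1.0]
dist_logits = [0.1, 1.5, -0.5]
teacher_logits = [0.2, 3.0, 0.1]  # the teacher predicts class 1
loss = hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label=0)
print(f"hard distillation loss: {loss:.4f}")
```

Note that the teacher's opinion enters only through its predicted class, so the student can be trained from a CNN's outputs without access to the CNN's internals.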
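Classifying Images with DeiT
Follow the ``README.md`` in the DeiT repository for detailed information on how to classify images with DeiT, or, for a quick test, first install the required packages: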
pip install torch torchvision timm pandas requests
If you are running in Google Colab, install the dependencies with the following command:
!pip install timm pandas requests
Then run the script below:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# the tutorial was written for 1.8.0; the output shown below came from 2.7.0+cu126
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
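# Standard ImageNet evaluation preprocessing: resize, center-crop, normalize
transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),  # same as the integer code 3
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])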
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.7.0+cu126
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /home/user/.cache/torch/hub/main.zip
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/timm/models/registry.py:4: FutureWarning:
Importing from timm.models.registry is deprecated, please import via timm.models
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/timm/models/layers/__init__.py:48: FutureWarning:
Importing from timm.models.layers is deprecated, please import via timm.layers
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:
Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:
Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:
Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:
Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:
Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:
Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:
Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
/home/user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:
Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /home/user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|##########| 330M/330M [00:33<00:00, 10.4MB/s]
269
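The output should be 269, which, according to the ImageNet `class index to label file <https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a>`_, maps to ``timber wolf, grey wolf, gray wolf, Canis lupus``.
Now that we have verified that the DeiT model can classify images, let's see how to modify the model so that it can run in iOS and Android apps.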
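The classification step reduces to picking the index of the largest score and looking it up in a label table. The sketch below shows that step in plain Python, with a 1000-element list standing in for the model output. The helper names ``softmax`` and ``top_k`` and the toy scores are hypothetical; only the label for index 269 is taken from the text above.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (class index, probability) pairs."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Only index 269 is known from the tutorial; other labels are left out.
labels = {269: "timber wolf, grey wolf, gray wolf, Canis lupus"}

# Toy scores standing in for the 1000 ImageNet logits.
logits = [0.0] * 1000
logits[269] = 9.0
logits[270] = 7.0  # hypothetical runner-up

for idx, prob in top_k(logits, k=2):
    print(idx, f"{prob:.3f}", labels.get(idx, "<label not listed here>"))
```

On the real output tensor, ``torch.softmax(out, dim=1)`` followed by ``torch.topk`` should give the same kind of ranking.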
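Scripting DeiT
To use the model on mobile, we first need to script it. See the `Script and Optimize recipe <https://pytorch.org/tutorials/recipes/script_optimized.html>`_ for a quick overview. Run the code below to convert the DeiT model used in the previous step to the TorchScript format, which can run on mobile.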
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /home/user/.cache/torch/hub/facebookresearch_deit_main
This generates the scripted model file ``fbdeit_scripted.pt``, which is about 346MB in size.
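Quantizing DeiT
To reduce the size of the trained model significantly while keeping inference accuracy about the same, we can apply quantization to the model. Because DeiT is built on the Transformer architecture, dynamic quantization is easy to apply: it works best for LSTM and Transformer models (see `here <https://pytorch.org/docs/stable/quantization.html?highlight=quantization#dynamic-quantization>`_ for more details).
Now run the code below: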
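The size saving comes from storing each weight as an 8-bit integer plus a shared scale and zero point, instead of a 32-bit float. The sketch below illustrates that arithmetic in plain Python on a handful of made-up weights. It is a simplified illustration under stated assumptions, not PyTorch's actual kernels: the function names and values are hypothetical, and real dynamic quantization also quantizes activations on the fly at inference time.

```python
def quantize_int8(values):
    """Map floats to int8 codes with an affine (scale, zero_point) scheme."""
    lo = min(min(values), 0.0)  # the representable range must include 0
    hi = max(max(values), 0.0)
    scale = (hi - lo) / 255 or 1.0  # fall back to 1.0 if all values are 0
    zero_point = round(-128 - lo / scale)
    codes = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return codes, scale, zero_point


def dequantize(codes, scale, zero_point):
    """Recover approximate floats from int8 codes."""
    return [(c - zero_point) * scale for c in codes]


weights = [-0.62, -0.10, 0.0, 0.33, 0.91]  # made-up float32 weights
codes, scale, zero_point = quantize_int8(weights)
restored = dequantize(codes, scale, zero_point)
max_err = max(abs(a - b) for a, b in zip(weights, restored))

print("int8 codes:", codes)
print(f"scale={scale:.5f} zero_point={zero_point} max_err={max_err:.5f}")
```

Each weight shrinks from 4 bytes to 1 byte, at the cost of a rounding error that in this example stays below one quantization step (``scale``).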
# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86"  # switching to ``qnnpack`` made the quantized model much slower in this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/ao/quantization/observer.py:244: UserWarning:
Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.
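This generates the scripted and quantized version of the model, ``fbdeit_scripted_quantized.pt``, which is about 89MB in size, a 74% reduction from the non-quantized model's 346MB!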
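The 74% figure follows directly from the two file sizes. The sketch below reproduces the arithmetic and shows one way to measure file sizes yourself. The helpers ``size_mb`` and ``reduction_percent`` are hypothetical names, the temporary files exist only so the example runs anywhere, and treating 1MB as 1024 * 1024 bytes is an assumption about how the sizes above were measured.

```python
import os
import tempfile


def size_mb(path):
    """Return the size of a file in MB (here 1 MB = 1024 * 1024 bytes)."""
    return os.path.getsize(path) / (1024 * 1024)


def reduction_percent(before, after):
    """Percentage by which `after` is smaller than `before`."""
    return (before - after) / before * 100


# Sizes quoted in the text above, in MB.
print(f"{reduction_percent(346, 89):.1f}%")  # prints 74.3%

# Exercise size_mb on two throwaway files of 4 MB and 1 MB.
with tempfile.TemporaryDirectory() as tmp:
    big = os.path.join(tmp, "big.bin")
    small = os.path.join(tmp, "small.bin")
    with open(big, "wb") as f:
        f.write(b"\0" * (4 * 1024 * 1024))
    with open(small, "wb") as f:
        f.write(b"\0" * (1024 * 1024))
    demo_reduction = reduction_percent(size_mb(big), size_mb(small))

print(f"{demo_reduction:.1f}%")  # prints 75.0%
```

After running the tutorial, calling ``size_mb("fbdeit_scripted.pt")`` and ``size_mb("fbdeit_scripted_quantized.pt")`` should give numbers close to the ones quoted above.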
You can use ``scripted_quantized_model`` to generate the same inference result:
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
269
Optimizing DeiT
The final step before using the quantized and scripted model on mobile is to optimize it:
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
The generated ``fbdeit_optimized_scripted_quantized.pt`` file is about the same size as the quantized, scripted, but non-optimized model. The inference result remains the same.
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
269
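Using Lite Interpreter
To see how much model size reduction and inference speedup the Lite Interpreter can bring, create the lite version of the model.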
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
Although the lite model is about the same size as the non-lite version, it is expected to run inference faster on mobile.
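Comparing Inference Speed
To compare the inference speed of the five models (the original model, the scripted model, the quantized-and-scripted model, the optimized-quantized-and-scripted model, and the lite model), run the code below: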
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)
print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 42.05ms
scripted model: 38.41ms
scripted & quantized model: 54.18ms
scripted & quantized & optimized model: 25.90ms
lite model: 24.74ms
The results from a run on Google Colab are:
original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms
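The two sets of numbers above disagree on the scripted and quantized model: it is slower than the original in the first run (54.18ms against 42.05ms) but about twice as fast on Colab (593.19ms against 1236.69ms). A single profiled call can be noisy, for example because of first-call overhead, so one way to get steadier numbers is to warm up and take the median of several runs. The sketch below is a generic timing helper in plain Python; ``benchmark_ms``, its default arguments, and ``busy_work`` are hypothetical names for illustration and are not part of the tutorial's code.

```python
import statistics
import time


def benchmark_ms(fn, warmup=3, runs=10):
    """Return the median wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):  # warm-up calls are not timed
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000)
    return statistics.median(timings)


def busy_work():
    """Stand-in workload so the sketch runs without a model."""
    return sum(i * i for i in range(10_000))


print(f"busy_work: {benchmark_ms(busy_work):.3f}ms")
```

With the models from this tutorial, a call such as ``benchmark_ms(lambda: scripted_model(img))`` would time one variant; wall-clock medians are not directly comparable to the profiler's ``self_cpu_time_total`` values.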
The following results summarize the inference time of each model and its percentage reduction relative to the original model.
import pandas as pd
df = pd.DataFrame({'Model': ['original model', 'scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
# One row per model: formatted inference time and reduction relative to prof1
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)
print(df)
"""
Model Inference Time Reduction
0 original model 1236.69ms 0%
1 scripted model 1226.72ms 0.81%
2 scripted & quantized model 593.19ms 52.03%
3 scripted & quantized & optimized model 598.01ms 51.64%
4 lite model 600.72ms 51.43%
"""
                                    Model  Inference Time  Reduction
0                          original model         42.05ms         0%
1                          scripted model         38.41ms      8.66%
2              scripted & quantized model         54.18ms    -28.83%
3  scripted & quantized & optimized model         25.90ms     38.41%
4                              lite model         24.74ms     41.16%