Save and Load the Model
Created On: Feb 09, 2021 | Last Updated: Oct 15, 2024 | Last Verified: Nov 05, 2024
In this section, we will look at how to persist model state by saving, loading, and running model predictions.
import torch
import torchvision.models as models
Saving and Loading Model Weights

PyTorch models store the learned parameters in an internal state dictionary called the state_dict. These can be persisted via the torch.save method:
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/user/.cache/torch/hub/checkpoints/vgg16-397923af.pth
  0%|          | 0.00/528M [00:00<?, ?B/s]
100%|##########| 528M/528M [00:19<00:00, 28.3MB/s]
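To make the state_dict concept concrete: it is simply an ordered mapping from parameter (and buffer) names to tensors, the same for a small model as for VGG16. A minimal sketch with a toy stand-in model (the layer sizes here are arbitrary):

```python
import torch
import torch.nn as nn

# A small stand-in model; the state_dict mechanics are identical to vgg16.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict maps each parameter name to its tensor.
sd = model.state_dict()
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2, 4) shapes follow each Linear layer's (out_features, in_features)
```

Note that only layers with parameters appear in the dict; the parameter-free ReLU at index 1 contributes no entries.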
To load model weights, you need to create an instance of the same model first, and then load the parameters using the load_state_dict() method.

In the code below, we set weights_only=True to limit the functions executed during unpickling to only those required for loading the weights. Using weights_only=True is considered a best practice when loading weights.
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
Note

Be sure to call the model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
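The inconsistency is easy to demonstrate: in training mode, dropout randomly zeroes activations, so repeated calls on the same input can differ; in eval mode, dropout is a no-op and the output is deterministic. A minimal sketch (the toy model here is illustrative, not VGG):

```python
import torch
import torch.nn as nn

# A toy model containing dropout, standing in for a real network.
model = nn.Sequential(nn.Linear(8, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))

model.eval()  # dropout becomes identity; batchnorm (if present) uses running stats

with torch.no_grad():  # also disable gradient tracking for inference
    x = torch.randn(1, 8)
    out1 = model(x)
    out2 = model(x)

# In eval mode, repeated forward passes on the same input agree exactly.
print(torch.equal(out1, out2))  # True
```

torch.no_grad() is independent of eval(): eval() changes layer behavior, while no_grad() skips building the autograd graph, saving memory and time during inference.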
Saving and Loading Models with Shapes

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of the network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:
torch.save(model, 'model.pth')
We can then load the model as demonstrated below.

As described in Saving and loading torch.nn.Modules, saving state_dict is considered the best practice. However, below we use weights_only=False because this involves loading the model, which is a legacy use case for torch.save.
model = torch.load('model.pth', weights_only=False)
Note

This approach uses the Python pickle module when serializing the model, and thus relies on the actual class definition being available when loading the model.
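The class-definition requirement can be seen in a self-contained round trip: pickle stores a reference to the class, not the class itself, so torch.load only succeeds where that class can be imported. A sketch using an in-memory buffer instead of a file (the TinyNet class is a hypothetical example):

```python
import io
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """This class definition must be importable wherever torch.load runs."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()

# Save the whole module object; pickle records a reference to TinyNet,
# not the class's source code.
buf = io.BytesIO()
torch.save(model, buf)

buf.seek(0)
loaded = torch.load(buf, weights_only=False)  # fails if TinyNet is not defined here

print(type(loaded).__name__)                            # TinyNet
print(torch.equal(loaded.fc.weight, model.fc.weight))   # True
```

If the same file were loaded in a fresh script without the TinyNet definition, torch.load would raise an error; this is why saving the state_dict, which is class-agnostic, is the recommended approach.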