Save and Load the Model


Created On: Feb 09, 2021 | Last Updated: Oct 15, 2024 | Last Verified: Nov 05, 2024

In this section we will look at how to persist model state by saving and loading models and running model predictions.

import torch
import torchvision.models as models

Saving and Loading Model Weights

PyTorch models store the learned parameters in an internal state dictionary, called state_dict. These can be persisted via the torch.save method:

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/user/.cache/torch/hub/checkpoints/vgg16-397923af.pth

100%|##########| 528M/528M [00:19<00:00, 28.3MB/s]
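The state_dict returned above is an ordinary mapping from parameter names to tensors, so it can be inspected directly. As a minimal sketch (using a small hypothetical model instead of vgg16, to avoid the 528M download):

```python
import torch
import torch.nn as nn

# A tiny stand-in model; in a Sequential, parameters are named by layer index
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

sd = model.state_dict()
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)
```

Passing this dictionary to torch.save serializes exactly these named tensors and nothing else about the model object.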

To load model weights, you need to create an instance of the same model first, and then load the parameters using the load_state_dict() method.

In the code below, we set weights_only=True to limit the functions executed during unpickling to only those needed for loading weights. Using weights_only=True is considered a best practice when loading weights.

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
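The round trip above can be verified on any model by comparing parameters before and after loading. A minimal sketch, again using a small hypothetical model and file name in place of vgg16:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                        # tiny stand-in model
torch.save(model.state_dict(), 'weights.pth')  # persist its parameters

fresh = nn.Linear(4, 2)                        # same architecture, new random weights
fresh.load_state_dict(torch.load('weights.pth', weights_only=True))

# After loading, the parameters match the saved model exactly
print(torch.equal(model.weight, fresh.weight))  # True
print(torch.equal(model.bias, fresh.bias))      # True
```

Note that load_state_dict requires the architectures to match: a shape mismatch between the saved tensors and the target model's parameters raises an error rather than loading partially.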

Note

Be sure to call the model.eval() method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
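The effect of eval mode can be seen directly on a single Dropout layer, sketched here: in training mode it randomly zeroes elements (scaling the survivors by 1/(1-p)), while in evaluation mode it is the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()            # training mode: elements are zeroed at random
out_train = drop(x)     # each element is either 0.0 or 2.0 (scaled by 1/(1-p))

drop.eval()             # evaluation mode: dropout is disabled
out_eval = drop(x)
print(torch.equal(out_eval, x))  # True
```

Calling model.eval() on a whole model recursively puts every such layer (and every BatchNorm layer, which switches to its running statistics) into this deterministic mode.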

Saving and Loading Models with Structure

When loading model weights, we needed to instantiate the model class first, because the class defines the structure of the network. We might want to save the structure of this class together with the model, in which case we can pass model (and not model.state_dict()) to the saving function:

torch.save(model, 'model.pth')

We can then load the model as demonstrated below.

As described in Saving and loading torch.nn.Modules, saving state_dict is considered the best practice. However, below we use weights_only=False because this involves loading the model, which is a legacy use case for torch.save.

model = torch.load('model.pth', weights_only=False)

Note

This approach uses the Python pickle module when serializing the model, thus it relies on the actual class definition being available when loading the model.
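Concretely, a whole-model round trip looks like the sketch below (TinyNet and the file name are hypothetical). Because pickle stores only a reference to the class, the same class definition must be importable wherever torch.load is later called.

```python
import torch
import torch.nn as nn

# Hypothetical custom model; its class definition must be available
# (importable under the same module path) at load time.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
torch.save(model, 'tiny_model.pth')  # pickles the whole object, not just weights

# weights_only=False unpickles arbitrary objects; only do this with trusted files
loaded = torch.load('tiny_model.pth', weights_only=False)
print(type(loaded).__name__)  # TinyNet
```

If the class has been renamed, moved, or deleted since saving, loading fails with an unpickling error, which is one reason saving state_dict is preferred for long-lived checkpoints.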
