How to save memory by fusing the optimizer step into the backward pass
Created On: Oct 02, 2023 | Last Updated: Jan 16, 2024 | Last Verified: Nov 05, 2024
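Hello there! This tutorial showcases one way to reduce the memory footprint of a training loop: reducing the memory taken by the gradients. Say you have a model and you want to optimize memory, either to avoid Out of Memory (OOM) errors or simply to get more out of your GPU. Well, you _might_ be in luck (if gradients take up a sizable portion of your memory and you do not need gradient accumulation). We will explore the following: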
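What takes up memory during a training or fine-tuning loop,
How to capture and visualize memory snapshots to determine the bottleneck,
The new Tensor.register_post_accumulate_grad_hook(hook) API, and finally,
How everything fits together in 10 lines of code to achieve memory savings.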
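To run this tutorial, you will need:
PyTorch 2.1.0 or newer with torchvision
1 CUDA GPU if you'd like to run the memory visualizations locally. Otherwise, the technique benefits any device in the same way.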
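Let us start by importing the required modules and models. We will use a vision transformer model from torchvision, but feel free to substitute your own model. We will also use torch.optim.Adam as our optimizer, but, again, feel free to substitute your own optimizer.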
import torch
from torchvision import models
from pickle import dump
model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
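Now let's define our typical training loop. You should use real images when training, but for the purposes of this tutorial we pass in fake inputs and do not worry about loading any actual data.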
IMAGE_SIZE = 224

def train(model, optimizer):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()
    # optimizer update
    optimizer.step()
    optimizer.zero_grad()
Memory usage during training
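We are about to look at some memory snapshots, so we should be prepared to analyze them properly. Typically, training memory consists of:
Model parameters (size P)
Activations that are saved for the backward pass (size A)
Gradients, which are the same size as the model parameters, so size G = P.
Optimizer state, which is proportional to the size of the parameters. In this case, the state for Adam requires 2x the model parameters, so size O = 2P.
Intermediate tensors, which are allocated throughout the compute. We will not worry about them for now as they are usually small and ephemeral.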
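The breakdown above can be turned into a rough back-of-the-envelope estimate. The sketch below is illustrative only: `estimate_training_memory_bytes` and `toy_model` are hypothetical names that are not part of this tutorial, and the estimate assumes an Adam-style optimizer with two state tensors per parameter while ignoring activations (A) and intermediates.

```python
import torch
from torch import nn


def estimate_training_memory_bytes(model: nn.Module, optimizer_state_multiplier: int = 2) -> dict:
    """Rough static estimate of P, G and O in bytes (activations and intermediates ignored)."""
    # P: bytes held by the parameters themselves
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    # G: one gradient per trainable parameter, same shape and dtype (G = P if all are trainable)
    grad_bytes = sum(
        p.numel() * p.element_size() for p in model.parameters() if p.requires_grad
    )
    # O: Adam keeps exp_avg and exp_avg_sq per trainable parameter, hence the default of 2
    # (the multiplier is an assumption; other optimizers differ)
    optimizer_state_bytes = optimizer_state_multiplier * grad_bytes
    return {
        "params": param_bytes,
        "grads": grad_bytes,
        "optimizer_state": optimizer_state_bytes,
        "total": param_bytes + grad_bytes + optimizer_state_bytes,
    }


# Hypothetical stand-in model so the sketch runs quickly on CPU.
toy_model = nn.Linear(1000, 1000)
estimate = estimate_training_memory_bytes(toy_model)
for name, num_bytes in estimate.items():
    print(f"{name}: {num_bytes / 1e6:.1f} MB")
```

Running the same function on your own model gives a first guess at how much of the budget the gradients could account for, before you capture any snapshot.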
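Capturing and visualizing memory snapshots
Let's get ourselves a memory snapshot! As your code runs, consider what you expect the CUDA memory timeline to look like.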
# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps
for _ in range(3):
    train(model, optimizer)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
Now open the snapshot in the CUDA Memory Visualizer at https://pytorch.org/memory_viz by dragging and dropping the snapshot.pickle file. Does the memory timeline match your expectations?

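The model parameters have already been loaded in memory before the training step, so we see a chunk of memory devoted to the weights right off the bat. As we start our forward pass, memory is allocated gradually for the activations, that is, the tensors we save so that we can compute gradients in the backward pass. Once we start the backward pass, the activations are gradually freed while the memory for the gradients starts building up.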
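Lastly, as the optimizer kicks in, its state is lazily initialized, so we should see the optimizer state memory gradually increase during the optimizer step of the first training loop only. In later loops, the optimizer memory remains allocated and is updated in place. The memory for the gradients is then freed at the end of every training loop when zero_grad is called.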
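The lazy initialization of the optimizer state, and the freeing of gradients by zero_grad, can also be observed without a GPU. The following is a minimal sketch on a hypothetical `toy_model` (not the ViT used in this tutorial), assuming the default Adam settings of PyTorch 2.x.

```python
import torch
from torch import nn

# Hypothetical stand-in model so the sketch runs on CPU.
toy_model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(toy_model.parameters())

# Before the first step, Adam has not allocated any state yet.
state_entries_before = len(optimizer.state)

toy_model(torch.rand(3, 4)).sum().backward()
optimizer.step()

# After the first step, every parameter has two same-sized state tensors
# (exp_avg and exp_avg_sq) plus a small step counter.
weight_state = optimizer.state[toy_model.weight]
state_keys = sorted(weight_state.keys())

# zero_grad() sets .grad to None by default in PyTorch 2.x, which releases the gradient memory.
optimizer.zero_grad()
grad_after_zero = toy_model.weight.grad

print(state_entries_before, state_keys, grad_after_zero)
```

The two same-sized state tensors per parameter are where the O = 2P figure for Adam comes from.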
Where is the memory bottleneck in this training loop? Or, in other words, where is the peak memory?
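The peak memory usage is during the optimizer step! Note that the memory then consists of ~1.2GB of parameters, ~1.2GB of gradients, and ~2.4GB = 2 * 1.2GB of optimizer state, as expected. The last ~1.2GB comes from the Adam optimizer requiring memory for intermediates, for a total of ~6GB of peak memory. Technically, you can remove the need for the last 1.2GB of optimizer intermediates by setting Adam(model.parameters(), foreach=False), which trades runtime for memory. If switching off the foreach runtime optimization saves enough memory for you, nice, but please read on if you're curious how this tutorial can help you do better! With the technique we will soon introduce, we will reduce peak memory by removing the need for the ~1.2GB of gradient memory as well as the optimizer intermediates memory. Now, what would you expect the new peak memory to be? The answer will be revealed in the next snapshot.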
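As a worked version of the arithmetic above, the snippet below adds up the approximate sizes quoted for the baseline snapshot. The constant names are illustrative, and the values are the rounded figures from the text rather than exact measurements.

```python
# Approximate sizes in GB, as quoted for the baseline snapshot.
PARAMS_GB = 1.2
GRADS_GB = PARAMS_GB            # G = P
ADAM_STATE_GB = 2 * PARAMS_GB   # O = 2P
FOREACH_INTERMEDIATES_GB = 1.2  # temporaries used by the foreach Adam implementation

# Everything that is alive at the same time during the optimizer step.
baseline_peak_gb = PARAMS_GB + GRADS_GB + ADAM_STATE_GB + FOREACH_INTERMEDIATES_GB
# Fraction of that peak taken by the gradients alone.
gradient_share = GRADS_GB / baseline_peak_gb

print(f"peak during optimizer step: ~{baseline_peak_gb:.1f} GB")
print(f"share taken by gradients:   ~{gradient_share:.0%}")
```

With these rounded inputs the peak comes out at roughly 6 GB, with gradients accounting for about a fifth of it.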
Disclaimer: this technique is **not** for everyone
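Before we get too excited, we have to consider whether this technique applies to your use case. It is not a silver bullet! Fusing the optimizer step into the backward pass only targets the memory taken by *gradients* (and, as a side effect, the optimizer intermediates). Thus, the larger the share of memory taken up by the gradients, the larger the memory reduction. In our example above, the gradients take up 20% of the peak memory (~1.2GB of ~6GB), which is quite sizable!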
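This may not be the case for you. For example, if the set of weights you train is already very small (say, because you apply LoRA), then the gradients take up little space in your training loop and the gains are far less exciting. In that case, you should first try other techniques such as activation checkpointing, distributed training, quantization, or reducing the batch size. Then, when the gradients are part of the bottleneck again, come back to this tutorial!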
Still here? Cool, let's introduce the new register_post_accumulate_grad_hook(hook) API on Tensor.
The Tensor.register_post_accumulate_grad_hook(hook) API and our technique
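Our technique relies on not having to keep the gradients around during backward(). Instead, once a gradient has been accumulated, we immediately apply the optimizer to the corresponding parameter and drop that gradient entirely! This removes the need to hold on to a big buffer of gradients until the optimizer step.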
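So how can we unlock the behavior of applying the optimizer more eagerly? In the 2.1 release, PyTorch added a new API, torch.Tensor.register_post_accumulate_grad_hook(), which lets us attach a hook to a Tensor that runs once its .grad field has been accumulated. We will encapsulate the optimizer step in this hook. How?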
How everything fits together in 10 lines
Remember our model and optimizer setup from earlier? They are left commented out below so we don't spend resources rerunning that code.
# model = models.vit_l_16(weights='DEFAULT').cuda()
# optimizer = torch.optim.Adam(model.parameters())

# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
    # create our fake image input: tensor shape is batch_size, channels, height, width
    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
    # call our forward and backward
    loss = model.forward(fake_image)
    loss.sum().backward()
    # optimizer update --> no longer needed!
    # optimizer.step()
    # optimizer.zero_grad()
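That took about 10 lines of changes in our sample model, which is neat. However, for real models, switching the optimizer out for a dictionary of optimizers could be a fairly intrusive change, especially for those who use LRSchedulers or manipulate optimizer configuration throughout the training epochs. Making this API work with those changes will be more involved and will likely require moving more configuration into global state, but it should not be impossible. That said, a next step for PyTorch is to make this API easier to adopt with LRSchedulers and other features you are already used to.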
But let me get back to convincing you that this technique is worth it. We will consult our friend, the memory snapshot.
# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer

# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')

# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
    train(model)

# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open("snapshot-opt-in-bwd.pickle", "wb") as f:
    dump(s, f)

# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
Yes, take some time to drag your snapshot into the CUDA Memory Visualizer.

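Several major observations:
There is no more optimizer step! Right... we fused it into the backward pass.
Likewise, the backward pass drags on longer and there are more scattered allocations for intermediates. This is expected, as the optimizer step requires intermediates.
Most importantly! The peak memory is lower! It is now ~4GB (which we hope maps closely to your earlier expectation).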
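Note that there is no longer any big chunk of memory allocated for the gradients compared to before, which accounts for ~1.2GB of memory savings. Instead, we free each gradient very quickly after it has been computed by moving the optimizer step as far ahead as we can. Woohoo! By the way, the other ~1.2GB of memory savings comes from breaking the optimizer apart into per-parameter optimizers, so the intermediates have shrunk proportionally. This detail is less important than the gradient memory savings, as you can get the optimizer intermediates savings just by setting foreach=False, without this technique.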
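You may be correctly wondering: if we saved 2.4GB of memory, why is the peak memory not 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak is now near the start of the backward pass, when we still have activations in memory, whereas before the peak was during the optimizer step, when the activations had already been freed. The ~0.4GB difference (~4.0GB - ~3.6GB) is therefore due to the activations memory. One can then imagine that this technique can be coupled with activation checkpointing for more memory wins.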
Conclusion
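In this tutorial, we learned about the memory-saving technique of fusing the optimizer into the backward step through the new Tensor.register_post_accumulate_grad_hook() API, and about *when* to apply it (when gradient memory is significant). Along the way, we also learned about memory snapshots, which are generally useful in memory optimization.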
Total running time of the script: (1 minute 3.554 seconds)