Shortcuts

使用 ZeroRedundancyOptimizer 分片优化器状态

Created On: Feb 26, 2021 | Last Updated: Oct 20, 2021 | Last Verified: Not Verified

在此教程中,您将学习:

要求

什么是 ZeroRedundancyOptimizer

ZeroRedundancyOptimizer 的理念来自 DeepSpeed/ZeRO 项目Marian,它通过在分布式数据并行进程间分片优化器状态来减少每个进程的内存占用。在 分布式数据并行入门 教程中,我们展示了如何使用 DistributedDataParallel (DDP) 来训练模型。在该教程中,每个进程保留一个独立的优化器副本。由于 DDP 已在反向传播过程中同步了梯度,因此每次迭代中所有优化器副本都将对相同的参数和梯度值进行操作,这就是 DDP 确保模型副本状态一致的方式。通常情况下,优化器还会维护本地状态。例如,Adam 优化器使用逐参数的 exp_avgexp_avg_sq 状态。因此,Adam 优化器的内存消耗至少是模型大小的两倍。鉴于此,我们可以通过在 DDP 进程间分片优化器状态来减少优化器的内存占用。更具体地说,不再为所有参数创建逐参数状态,而是每个 DDP 进程中的优化器实例仅保留一部分模型参数的优化器状态。优化器的 step() 函数只更新其分片中的参数,然后将更新后的参数广播到所有其他 DDP 对等进程,以确保所有模型副本仍保持一致状态。

如何使用 ZeroRedundancyOptimizer

下面的代码展示了如何使用 ZeroRedundancyOptimizer。大部分代码与 分布式数据并行注释 中提供的简单 DDP 示例类似。主要区别在于 example 函数中的 if-else 子句,它封装了优化器的构造,在 ZeroRedundancyOptimizerAdam 优化器之间切换。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

下面是输出结果。当启用 ZeroRedundancyOptimizerAdam 时,优化器 step() 的峰值内存消耗是普通 Adam 的一半。这符合我们的预期,因为我们在两个进程之间分片了 Adam 优化器状态。输出还显示,即使使用 ZeroRedundancyOptimizer,模型参数在一次迭代后仍保持相同的值(使用和不使用 ZeroRedundancyOptimizer 的参数总和相同)。

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源