使用 ZeroRedundancyOptimizer 分片优化器状态¶
Created On: Feb 26, 2021 | Last Updated: Oct 20, 2021 | Last Verified: Not Verified
在此教程中,您将学习:
ZeroRedundancyOptimizer 的高级理念。
如何在分布式训练中使用 ZeroRedundancyOptimizer 以及其影响。
什么是 ZeroRedundancyOptimizer
?¶
ZeroRedundancyOptimizer 的理念来自 DeepSpeed/ZeRO 项目 和 Marian,它通过在分布式数据并行进程间分片优化器状态来减少每个进程的内存占用。在 分布式数据并行入门 教程中,我们展示了如何使用 DistributedDataParallel (DDP) 来训练模型。在该教程中,每个进程保留一个独立的优化器副本。由于 DDP 已在反向传播过程中同步了梯度,因此每次迭代中所有优化器副本都将对相同的参数和梯度值进行操作,这就是 DDP 确保模型副本状态一致的方式。通常情况下,优化器还会维护本地状态。例如,Adam
优化器使用逐参数的 exp_avg
和 exp_avg_sq
状态。因此,Adam
优化器的内存消耗至少是模型大小的两倍。鉴于此,我们可以通过在 DDP 进程间分片优化器状态来减少优化器的内存占用。更具体地说,不再为所有参数创建逐参数状态,而是每个 DDP 进程中的优化器实例仅保留一部分模型参数的优化器状态。优化器的 step()
函数只更新其分片中的参数,然后将更新后的参数广播到所有其他 DDP 对等进程,以确保所有模型副本仍保持一致状态。
如何使用 ZeroRedundancyOptimizer
?¶
下面的代码展示了如何使用 ZeroRedundancyOptimizer。大部分代码与 分布式数据并行注释 中提供的简单 DDP 示例类似。主要区别在于 example
函数中的 if-else
子句,它封装了优化器的构造,在 ZeroRedundancyOptimizer 和 Adam
优化器之间切换。
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
def print_peak_memory(prefix, device):
if device == 0:
print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")
def example(rank, world_size, use_zero):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
# create default process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# create local model
model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
print_peak_memory("Max memory allocated after creating local model", rank)
# construct DDP model
ddp_model = DDP(model, device_ids=[rank])
print_peak_memory("Max memory allocated after creating DDP", rank)
# define loss function and optimizer
loss_fn = nn.MSELoss()
if use_zero:
optimizer = ZeroRedundancyOptimizer(
ddp_model.parameters(),
optimizer_class=torch.optim.Adam,
lr=0.01
)
else:
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
# forward pass
outputs = ddp_model(torch.randn(20, 2000).to(rank))
labels = torch.randn(20, 2000).to(rank)
# backward pass
loss_fn(outputs, labels).backward()
# update parameters
print_peak_memory("Max memory allocated before optimizer step()", rank)
optimizer.step()
print_peak_memory("Max memory allocated after optimizer step()", rank)
print(f"params sum is: {sum(model.parameters()).sum()}")
def main():
world_size = 2
print("=== Using ZeroRedundancyOptimizer ===")
mp.spawn(example,
args=(world_size, True),
nprocs=world_size,
join=True)
print("=== Not Using ZeroRedundancyOptimizer ===")
mp.spawn(example,
args=(world_size, False),
nprocs=world_size,
join=True)
if __name__=="__main__":
main()
下面是输出结果。当启用 ZeroRedundancyOptimizer
和 Adam
时,优化器 step()
的峰值内存消耗是普通 Adam
的一半。这符合我们的预期,因为我们在两个进程之间分片了 Adam
优化器状态。输出还显示,即使使用 ZeroRedundancyOptimizer
,模型参数在一次迭代后仍保持相同的值(使用和不使用 ZeroRedundancyOptimizer
的参数总和相同)。
=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875