Asynchronous Saving with Distributed Checkpoint (DCP)
Created On: Jul 22, 2024 | Last Updated: Jul 22, 2024 | Last Verified: Nov 05, 2024
Authors: Lucas Pasqualin, Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang
Checkpointing is often a bottleneck in the critical path of distributed training workloads, and it grows more expensive as model and world sizes increase. One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example from the Getting Started with Distributed Checkpoint tutorial to show how this can be integrated quite easily with torch.distributed.checkpoint.async_save.
What you will learn:
- How to generate checkpoints in parallel with DCP
- Effective strategies for optimizing performance

Prerequisites:
- PyTorch v2.4.0 or later
Asynchronous Checkpointing Overview
Before getting started with asynchronous checkpointing, it is important to understand its differences and limitations compared to synchronous checkpointing. Specifically:
- Memory requirements - Asynchronous checkpointing works by first copying the model into internal CPU buffers. This is helpful because it ensures that model and optimizer weights do not change while the checkpoint is still being written, but it raises CPU memory usage by a factor of checkpoint_size_per_rank X number_of_ranks. Additionally, users should take care to understand the memory constraints of their systems. In particular, pinned memory implies the use of page-locked memory, which is scarcer than pageable memory.
- Checkpoint management - Since checkpointing is asynchronous, it is up to the user to manage concurrently running checkpoints. In general, users can employ their own management strategies by handling the Future object returned by async_save. For most users, we recommend limiting checkpoints to one asynchronous request at a time, which avoids the additional memory pressure of each extra request; a minimal sketch of this pattern follows this list.
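For reference, here is a minimal sketch of that one-request-at-a-time pattern built around dcp.async_save. It assumes a process group is already initialized and that state_dict is a checkpointable dictionary; the helper name maybe_async_save and the module-level handle are illustrative only and not part of the DCP API:

import torch.distributed.checkpoint as dcp

_last_checkpoint = None  # hypothetical handle to the previously issued request

def maybe_async_save(state_dict, checkpoint_id):
    """Queue an asynchronous save, keeping at most one request in flight."""
    global _last_checkpoint
    if _last_checkpoint is not None:
        # Block until the previous asynchronous save has finished, bounding the
        # extra CPU memory to a single staging buffer.
        _last_checkpoint.result()
    _last_checkpoint = dcp.async_save(state_dict, checkpoint_id=checkpoint_id)

The complete runnable example below applies the same idea inside an FSDP training loop.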
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()
def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # waits for checkpointing to finish if one exists, avoiding queuing more than one checkpoint request at a time
        if checkpoint_future is not None:
            checkpoint_future.result()
        state_dict = { "app": AppState(model, optimizer) }
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
Even Faster Checkpointing with Pinned Memory
If the above optimization is still not fast enough, GPU models can take advantage of an additional optimization that stages checkpoints through a pinned memory buffer. This optimization targets the main overhead of asynchronous checkpointing, namely the in-memory copy into the checkpointing buffers. By keeping a pinned memory buffer alive between checkpoint requests, users can take advantage of direct memory access to speed up this copy.
Note

The main drawback of this optimization is that the buffer persists between checkpointing steps. Without the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as checkpointing finishes. With the pinned memory implementation, the buffer is maintained across steps, so the same peak memory pressure is sustained for the life of the application.
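The full script below repeats the earlier example; the only substantive change is the storage writer, sketched here in isolation. This sketch assumes PyTorch v2.4.0 or later, where FileSystemWriter exposes a cache_staged_state_dict flag, and it reuses the CHECKPOINT_DIR constant from the example:

from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

# Create the writer once, outside the training loop, so that its pinned staging
# buffer persists across checkpoint requests.
writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)

# Inside the loop, pass the same writer to every request, for example:
#   checkpoint_future = dcp.async_save(state_dict, storage_writer=writer)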
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()
def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
    # checkpoints to in-memory buffers. By setting `cache_staged_state_dict=True`, we enable efficient memory copying
    # into a persistent buffer with pinned memory enabled.
    # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
    # pinned memory buffer.
    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = { "app": AppState(model, optimizer) }
        if checkpoint_future is not None:
            # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
            checkpoint_future.result()
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    cleanup()
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
Conclusion
In conclusion, we have learned how to use DCP's async_save API to generate checkpoints off the critical training path. We have also covered the additional memory and concurrency overhead this API introduces, as well as the pinned memory optimization that speeds checkpointing up even further.