全面分片数据并行(FSDP2)入门¶
Created On: Mar 17, 2022 | Last Updated: May 16, 2025 | Last Verified: Nov 05, 2024
作者: Wei Feng, Will Constable, Yifan Mao
备注
从`PyTorch示例代码仓库 <https://github.com/pytorch/examples/tree/main/distributed/FSDP2>`_中查看本教程代码。FSDP1将被弃用,旧教程可以在`这里 <https://docs.pytorch.org/tutorials/intermediate/FSDP1_tutorial.html>`_找到。
FSDP2的工作原理¶
在`分布式数据并行 <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__ (DDP)训练中,每个任务拥有一个模型副本并处理一个批次的数据,最后使用全规约在任务之间同步梯度。
与DDP相比,FSDP通过对模型参数、梯度和优化器状态进行分片来减少GPU内存占用。这使得训练无法放入单个GPU中的模型变得可行。如下图所示:
在前向和后向计算之外,参数是完全分片的
在前向和后向计算之前,分片的参数被收集为未分片的参数
在后向计算中,本地未分片梯度被减少分散为分片的梯度
优化器使用分片的梯度更新分片的参数,从而生成分片的优化器状态
FSDP可以被认为是DDP的全规约的分解为减少分散和全收集操作
与`FSDP1 <https://docs.pytorch.org/docs/stable/fsdp.html>`_相比,FSDP2具有以下优点:
将分片参数表示为`DTensor <https://docs.pytorch.org/docs/stable/distributed.tensor.html>`_在维度i上分片,允许轻松操作各个参数,无需通信的分片状态字典,以及更简单的元设备初始化流程。
改进内存管理系统,通过避免使用``recordStream`` (文档)实现更低和可预测的GPU内存,并且无需任何CPU同步。
Offering a tensor subclass extension point to customize the all-gather, e.g. for float8 all-gather for float8 linears (doc), and NF4 for QLoRA (doc)
混合冻结和非冻结参数可以在同一个通信组中,而无需使用额外的内存。
如何使用FSDP2¶
模型初始化¶
对子模块应用fully_shard:与DDP不同,我们应对子模块和根模型应用`fully_shard <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html>`_。在下面的Transformer示例中,我们首先对每个层应用``fully_shard``,然后是根模型
在``layers[i]``的前向计算过程中,其余层被分片以减少内存占用
在``fully_shard(model)``中,FSDP2排除``model.layers``的参数,并将剩余参数分类为一个参数组,以实现高效的全收集和减少分散
fully_shard``将分片后的模型移动到实际的训练设备(例如``cuda
)
命令: torchrun --nproc_per_node 2 train.py
from torch.distributed.fsdp import fully_shard, FSDPModule
model = Transformer()
for layer in model.layers:
fully_shard(layer)
fully_shard(model)
assert isinstance(model, Transformer)
assert isinstance(model, FSDPModule)
print(model)
# FSDPTransformer(
# (tok_embeddings): Embedding(...)
# ...
# (layers): 3 x FSDPTransformerBlock(...)
# (output): Linear(...)
# )
我们可以使用``print(model)``检查嵌套包装情况。``FSDPTransformer``是`Transformer <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L100>`_和`FSDPModule <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_的联合类。同样的事情发生在`FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_上。所有FSDP2公开API通过``FSDPModule``暴露。例如,用户可以调用``model.unshard()``手动控制全收集计划。详细信息参见下面的“显式预取”。
model.parameters()作为DTensor: ``fully_shard``将在任务之间分片参数,并将``model.parameters()``从普通的``torch.Tensor``转化为DTensor以表示分片参数。FSDP2默认在维度0上分片,因此DTensor的配置为`Shard(dim=0)`。假设有N个任务以及一个有N行的参数,则在分片后每个任务拥有该参数的一行。我们可以使用``param.to_local()``检查分片参数。
from torch.distributed.tensor import DTensor
for param in model.parameters():
assert isinstance(param, DTensor)
assert param.placements == (Shard(0),)
# inspect sharded parameters with param.to_local()
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
注意优化器是在应用``fully_shard``之后构建的。模型和优化器状态字典都以DTensor表示。
DTensor方便了优化器、梯度裁剪和检查点保存
通过预取进行前向/后向计算¶
命令: torchrun --nproc_per_node 2 train.py
for _ in range(epochs):
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
``fully_shard``注册前向/后向钩子,以便在计算前全收集参数,并在计算后重新分片参数。为了与计算重叠全收集,FSDP2提供了**隐式预取**,可以直接与上述训练循环一起运行,以及面向高级用户的**显式预取**以手动控制全收集计划。
隐式预取: CPU线程在第i层前发出第i层的全收集。全收集被排队到其自己的CUDA流,而第i层的计算发生在默认流中。对于非CPU受限的工作负载(例如大型批量的Transformer),第i+1层的全收集可以与第i层的计算重叠。隐式预取在后向计算中类似进行,只是全收集是在前向计算完成的反顺序中发出的。
我们建议用户从隐式预取开始,以了解开箱即用的性能。
显式预取: 用户可以使用`set_modules_to_forward_prefetch <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_modules_to_forward_prefetch>`_指定前向顺序,以及使用`set_modules_to_backward_prefetch <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_modules_to_backward_prefetch>`_指定后向顺序。如下代码所示,CPU线程在第i层时发出第i+1层和第i+2层的全收集
显式预取在以下情况效果良好:
CPU受限工作负载: 如果使用隐式预取,CPU线程在第i层的内核执行时发出第i+1层的全收集过慢。我们必须显式在第i层运行前发出第i+1层的全收集
预取超过2层: 隐式预取每次仅全收集下一层,以保持最低内存占用。显式预取可以同时全收集多层,可能通过增加内存提高性能。参见代码中的``layers_to_prefetch``
更早发出第一个全收集: 隐式预取发生在调用``model(x)``时,第一个全收集才开始暴露。我们可以显式调用`model.unshard() <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.unshard>`_更早发出第一个全收集
命令: torchrun --nproc_per_node 2 train.py --explicit-prefetching
num_to_forward_prefetch = 2
for i, layer in enumerate(model.layers):
if i >= len(model.layers) - num_to_forward_prefetch:
break
layers_to_prefetch = [
model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
]
layer.set_modules_to_forward_prefetch(layers_to_prefetch)
num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
if i < num_to_backward_prefetch:
continue
layers_to_prefetch = [
model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
]
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
for _ in range(epochs):
# trigger 1st all-gather earlier
# this overlaps all-gather with any computation before model(x)
model.unshard()
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
启用混合精度¶
FSDP2提供了灵活的`混合精度策略 <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.MixedPrecisionPolicy>`_以加速训练。一个典型的用例是:
将float32参数转换为bfloat16以进行前向/后向计算,参见``param_dtype=torch.bfloat16``
将梯度上升转换为float32以进行减少分散以提高精度,参见``reduce_dtype=torch.float32``
与`torch.amp <https://docs.pytorch.org/docs/stable/amp.html>`_相比,FSDP2混合精度具有以下优势
高效和灵活的参数转换: ``FSDPModule``中的所有参数在模块边界处(前向和后向之前和之后)会一起转换。我们可以为每个层设置不同的混合精度策略。例如,前几层可以是float32,而其余的层可以是bfloat16。
float32梯度减少(减少分散): 梯度可能从任务到任务变化很大。使用float32减少梯度对于数值计算可能是关键的。
命令: torchrun --nproc_per_node 2 train.py --mixed-precision
model = Transformer(model_args)
fsdp_kwargs = {
"mp_policy": MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
)
}
for layer in model.layers:
fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
# sharded parameters are float32
for param in model.parameters():
assert param.dtype == torch.float32
# unsharded parameters are bfloat16
model.unshard()
for param in model.parameters(recurse=False):
assert param.dtype == torch.bfloat16
model.reshard()
# optimizer states are in float32
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
# training loop
# ...
使用DTensor进行梯度裁剪和优化器¶
命令: torchrun --nproc_per_node 2 train.py
# optim is constructed base on DTensor model parameters
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(epochs):
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
loss = model(x).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
optim.step()
optim.zero_grad()
优化器是在对模型应用``fully_shard``之后初始化的,并持有对DTensor``model.parameters()``的引用。对于梯度裁剪,``torch.nn.utils.clip_grad_norm_``适用于DTensor参数。张量操作将在DTensor内正确分派,以在任务之间通信部分张量以保留单设备语义。
使用DTensor API的状态字典¶
我们展示了如何将完整的状态字典转换为 DTensor 状态字典进行加载,以及如何将其转换回完整状态字典进行保存。
命令: torchrun --nproc_per_node 2 train.py
首次,为模型和优化器创建检查点。
第二次,从之前的检查点加载以继续训练。
加载状态字典:我们在元设备下初始化模型,并调用 fully_shard
将 model.parameters()
从普通 torch.Tensor
转换为 DTensor。在使用 torch.load 读取完整状态字典后,我们可以调用 distributed_tensor 将普通 torch.Tensor
转换为 DTensor,使用 model.state_dict()
中的相同放置方式和设备网格。最后,我们可以调用 model.load_state_dict 将 DTensor 状态字典加载到模型中。
from torch.distributed.tensor import distribute_tensor
# mmap=True reduces CPU memory usage
full_sd = torch.load(
"checkpoints/model_state_dict.pt",
mmap=True,
weights_only=True,
map_location='cpu',
)
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
sharded_tensor = distribute_tensor(
full_tensor,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
)
sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, assign=True)
保存状态字典:model.state_dict()
返回一个 DTensor 状态字典。我们可以通过调用 full_tensor() 将 DTensor 转换为普通 torch.Tensor
。内部调用了跨等级的 all-gather 来获取未分片参数的普通 torch.Tensor。在 0 号等级,full_param.cpu()
将张量逐一卸载到 CPU,以避免 GPU 内存峰值。
sharded_sd = model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
full_param = sharded_param.full_tensor()
if torch.distributed.get_rank() == 0:
cpu_state_dict[param_name] = full_param.cpu()
else:
del full_param
torch.save(cpu_state_dict, "checkpoints/model_state_dict.pt")
优化器状态字典的工作原理类似(代码)。用户可以自定义上述 DTensor 脚本以支持第三方检查点。
如果不需要自定义,我们可以直接使用 DCP APIs 来支持单节点和多节点训练。
使用 DCP APIs 的状态字典¶
命令:torchrun --nproc_per_node 2 train.py --dcp-api
首次,为模型和优化器创建检查点。
第二次,从之前的检查点加载以继续训练。
加载状态字典:我们可以通过 set_model_state_dict 将完整状态字典加载到 FSDP2 模型中。通过设置 broadcast_from_rank0=True
,我们可以仅在 0 号等级加载完整状态字典,以避免 CPU 内存过载。DCP 会将张量分片并广播到其他等级。
from torch.distributed.checkpoint.state_dict import set_model_state_dict
set_model_state_dict(
model=model,
model_state_dict=full_sd,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
保存状态字典:get_model_state_dict 加上 full_state_dict=True
和 cpu_offload=True
可以完成 all-gather 张量并将其卸载到 CPU。这与 DTensor APIs 的工作方式类似。
from torch.distributed.checkpoint.state_dict import get_model_state_dict
model_state_dict = get_model_state_dict(
model=model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=True,
)
)
torch.save(model_state_dict, "model_state_dict.pt")
有关使用 set_optimizer_state_dict 和 get_optimizer_state_dict 加载和保存优化器状态字典的更多信息,请参阅 pytorch/examples。
FSDP1 到 FSDP2 的迁移指南¶
让我们以 FSDP 的一个使用示例及其等效的 fully_shard 使用示例为例。我们将突出主要差异,并建议迁移步骤。
原始 FSDP() 使用方式
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})
model = FSDP(model, auto_wrap_policy=policy)
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)
新的 fully_shard() 使用方式
with torch.device("meta"):
model = Transformer()
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
assert tensor.device == torch.device("meta")
# Initialize the model after sharding
model.to_empty(device="cuda")
model.reset_parameters()
迁移步骤
替换导入
直接实现您的 ‘策略’ (对目标子层应用
fully_shard
)用
fully_shard
包裹您的根模型,而不是使用FSDP
去掉
param_init_fn
并手动调用model.reset_parameters()
替换其他 FSDP1 的关键参数(详情见下)
分片策略
FULL_SHARD:
reshard_after_forward=True
SHARD_GRAD_OP:
reshard_after_forward=False
HYBRID_SHARD:
reshard_after_forward=True
,使用二维设备网格_HYBRID_SHARD_ZERO2:
reshard_after_forward=False
,使用二维设备网格
CPU 卸载
CPUOffload.offload_params=False:
offload_policy=None
CPUOffload.offload_params=True:
offload_policy=CPUOffloadPolicy()
反向预取
BACKWARD_PRE: 总是使用
BACKWARD_POST: 不支持
混合精度
因为 fully_shard 不分片缓冲区,所以省略了
buffer_dtype
。fully_shard 的
cast_forward_inputs
映射到 FSDP1 中的cast_forward_inputs
和cast_root_forward_inputs
output_dtype
是一个新的配置用于 fully_shard
设备 ID: 从 device_mesh 的设备推断
sync_module_states=True/False: 已移至 DCP。用户可以使用 set_model_state_dict 和 broadcast_from_rank0=True
从 0 号等级广播状态字典
forward_prefetch: 可以手动控制预取
使用以下 API 控制自动预取:set_modules_to_forward_prefetch 和 set_modules_to_backward_prefetch
限制所有 all-gathers: 不再需要,因为 fully_shard
删除了 CPU 同步
使用原始参数: 始终使用原始参数(不再有平坦的参数)
no_sync(): set_requires_gradient_sync
忽略参数和忽略状态: ignored_params