• Tutorials >
  • 全面分片数据并行(FSDP2)入门
Shortcuts

全面分片数据并行(FSDP2)入门

Created On: Mar 17, 2022 | Last Updated: May 16, 2025 | Last Verified: Nov 05, 2024

作者: Wei Feng, Will Constable, Yifan Mao

备注

编辑 从`PyTorch示例代码仓库 <https://github.com/pytorch/examples/tree/main/distributed/FSDP2>`_中查看本教程代码。FSDP1将被弃用,旧教程可以在`这里 <https://docs.pytorch.org/tutorials/intermediate/FSDP1_tutorial.html>`_找到。

FSDP2的工作原理

在`分布式数据并行 <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__ (DDP)训练中,每个任务拥有一个模型副本并处理一个批次的数据,最后使用全规约在任务之间同步梯度。

与DDP相比,FSDP通过对模型参数、梯度和优化器状态进行分片来减少GPU内存占用。这使得训练无法放入单个GPU中的模型变得可行。如下图所示:

  • 在前向和后向计算之外,参数是完全分片的

  • 在前向和后向计算之前,分片的参数被收集为未分片的参数

  • 在后向计算中,本地未分片梯度被减少分散为分片的梯度

  • 优化器使用分片的梯度更新分片的参数,从而生成分片的优化器状态

FSDP 工作流程

FSDP可以被认为是DDP的全规约的分解为减少分散和全收集操作

FSDP的全收集和减少分散

与`FSDP1 <https://docs.pytorch.org/docs/stable/fsdp.html>`_相比,FSDP2具有以下优点:

  • 将分片参数表示为`DTensor <https://docs.pytorch.org/docs/stable/distributed.tensor.html>`_在维度i上分片,允许轻松操作各个参数,无需通信的分片状态字典,以及更简单的元设备初始化流程。

  • 改进内存管理系统,通过避免使用``recordStream`` (文档)实现更低和可预测的GPU内存,并且无需任何CPU同步。

  • Offering a tensor subclass extension point to customize the all-gather, e.g. for float8 all-gather for float8 linears (doc), and NF4 for QLoRA (doc)

  • 混合冻结和非冻结参数可以在同一个通信组中,而无需使用额外的内存。

如何使用FSDP2

模型初始化

对子模块应用fully_shard:与DDP不同,我们应对子模块和根模型应用`fully_shard <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html>`_。在下面的Transformer示例中,我们首先对每个层应用``fully_shard``,然后是根模型

  • 在``layers[i]``的前向计算过程中,其余层被分片以减少内存占用

  • 在``fully_shard(model)``中,FSDP2排除``model.layers``的参数,并将剩余参数分类为一个参数组,以实现高效的全收集和减少分散

  • fully_shard``将分片后的模型移动到实际的训练设备(例如``cuda

命令: torchrun --nproc_per_node 2 train.py

from torch.distributed.fsdp import fully_shard, FSDPModule
model = Transformer()
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

assert isinstance(model, Transformer)
assert isinstance(model, FSDPModule)
print(model)
#  FSDPTransformer(
#    (tok_embeddings): Embedding(...)
#    ...
#    (layers): 3 x FSDPTransformerBlock(...)
#    (output): Linear(...)
#  )

我们可以使用``print(model)``检查嵌套包装情况。``FSDPTransformer``是`Transformer <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L100>`_和`FSDPModule <​https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule>`_的联合类。同样的事情发生在`FSDPTransformerBlock <https://github.com/pytorch/examples/blob/70922969e70218458d2a945bf86fd8cc967fc6ea/distributed/FSDP2/model.py#L76C7-L76C18>`_上。所有FSDP2公开API通过``FSDPModule``暴露。例如,用户可以调用``model.unshard()``手动控制全收集计划。详细信息参见下面的“显式预取”。

model.parameters()作为DTensor: ``fully_shard``将在任务之间分片参数,并将``model.parameters()``从普通的``torch.Tensor``转化为DTensor以表示分片参数。FSDP2默认在维度0上分片,因此DTensor的配置为`Shard(dim=0)`。假设有N个任务以及一个有N行的参数,则在分片后每个任务拥有该参数的一行。我们可以使用``param.to_local()``检查分片参数。

from torch.distributed.tensor import DTensor
for param in model.parameters():
    assert isinstance(param, DTensor)
    assert param.placements == (Shard(0),)
    # inspect sharded parameters with param.to_local()

optim = torch.optim.Adam(model.parameters(), lr=1e-2)

注意优化器是在应用``fully_shard``之后构建的。模型和优化器状态字典都以DTensor表示。

DTensor方便了优化器、梯度裁剪和检查点保存

  • ``torch.optim.Adam``和``torch.nn.utils.clip_grad_norm_``能够直接作用于DTensor参数。这使得代码在单设备和分布式训练之间保持一致

  • 我们可以使用DTensor和DCP API操作参数以获得完整的状态字典,更多细节请参见“状态字典”部分。对于分布式状态字典,我们可以保存/加载检查点 (文档) 而无需额外通信

通过预取进行前向/后向计算

命令: torchrun --nproc_per_node 2 train.py

for _ in range(epochs):
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

``fully_shard``注册前向/后向钩子,以便在计算前全收集参数,并在计算后重新分片参数。为了与计算重叠全收集,FSDP2提供了**隐式预取**,可以直接与上述训练循环一起运行,以及面向高级用户的**显式预取**以手动控制全收集计划。

隐式预取: CPU线程在第i层前发出第i层的全收集。全收集被排队到其自己的CUDA流,而第i层的计算发生在默认流中。对于非CPU受限的工作负载(例如大型批量的Transformer),第i+1层的全收集可以与第i层的计算重叠。隐式预取在后向计算中类似进行,只是全收集是在前向计算完成的反顺序中发出的。

FSDP隐式预取

我们建议用户从隐式预取开始,以了解开箱即用的性能。

显式预取: 用户可以使用`set_modules_to_forward_prefetch <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_modules_to_forward_prefetch>`_指定前向顺序,以及使用`set_modules_to_backward_prefetch <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_modules_to_backward_prefetch>`_指定后向顺序。如下代码所示,CPU线程在第i层时发出第i+1层和第i+2层的全收集

显式预取在以下情况效果良好:

CPU受限工作负载: 如果使用隐式预取,CPU线程在第i层的内核执行时发出第i+1层的全收集过慢。我们必须显式在第i层运行前发出第i+1层的全收集

预取超过2层: 隐式预取每次仅全收集下一层,以保持最低内存占用。显式预取可以同时全收集多层,可能通过增加内存提高性能。参见代码中的``layers_to_prefetch``

更早发出第一个全收集: 隐式预取发生在调用``model(x)``时,第一个全收集才开始暴露。我们可以显式调用`model.unshard() <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.unshard>`_更早发出第一个全收集

命令: torchrun --nproc_per_node 2 train.py --explicit-prefetching

num_to_forward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i >= len(model.layers) - num_to_forward_prefetch:
        break
    layers_to_prefetch = [
        model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
    ]
    layer.set_modules_to_forward_prefetch(layers_to_prefetch)

num_to_backward_prefetch = 2
for i, layer in enumerate(model.layers):
    if i < num_to_backward_prefetch:
        continue
    layers_to_prefetch = [
        model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
    ]
    layer.set_modules_to_backward_prefetch(layers_to_prefetch)

for _ in range(epochs):
    # trigger 1st all-gather earlier
    # this overlaps all-gather with any computation before model(x)
    model.unshard()
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

启用混合精度

FSDP2提供了灵活的`混合精度策略 <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.MixedPrecisionPolicy>`_以加速训练。一个典型的用例是:

  • 将float32参数转换为bfloat16以进行前向/后向计算,参见``param_dtype=torch.bfloat16``

  • 将梯度上升转换为float32以进行减少分散以提高精度,参见``reduce_dtype=torch.float32``

与`torch.amp <https://docs.pytorch.org/docs/stable/amp.html>`_相比,FSDP2混合精度具有以下优势

  • 高效和灵活的参数转换: ``FSDPModule``中的所有参数在模块边界处(前向和后向之前和之后)会一起转换。我们可以为每个层设置不同的混合精度策略。例如,前几层可以是float32,而其余的层可以是bfloat16。

  • float32梯度减少(减少分散): 梯度可能从任务到任务变化很大。使用float32减少梯度对于数值计算可能是关键的。

命令: torchrun --nproc_per_node 2 train.py --mixed-precision

model = Transformer(model_args)
fsdp_kwargs = {
    "mp_policy": MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
}
for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)

# sharded parameters are float32
for param in model.parameters():
    assert param.dtype == torch.float32

# unsharded parameters are bfloat16
model.unshard()
for param in model.parameters(recurse=False):
    assert param.dtype == torch.bfloat16
model.reshard()

# optimizer states are in float32
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# training loop
# ...

使用DTensor进行梯度裁剪和优化器

命令: torchrun --nproc_per_node 2 train.py

# optim is constructed base on DTensor model parameters
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(epochs):
    x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    loss = model(x).sum()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optim.step()
    optim.zero_grad()

优化器是在对模型应用``fully_shard``之后初始化的,并持有对DTensor``model.parameters()``的引用。对于梯度裁剪,``torch.nn.utils.clip_grad_norm_``适用于DTensor参数。张量操作将在DTensor内正确分派,以在任务之间通信部分张量以保留单设备语义。

使用DTensor API的状态字典

我们展示了如何将完整的状态字典转换为 DTensor 状态字典进行加载,以及如何将其转换回完整状态字典进行保存。

命令: torchrun --nproc_per_node 2 train.py

  • 首次,为模型和优化器创建检查点。

  • 第二次,从之前的检查点加载以继续训练。

加载状态字典:我们在元设备下初始化模型,并调用 fully_shardmodel.parameters() 从普通 torch.Tensor 转换为 DTensor。在使用 torch.load 读取完整状态字典后,我们可以调用 distributed_tensor 将普通 torch.Tensor 转换为 DTensor,使用 model.state_dict() 中的相同放置方式和设备网格。最后,我们可以调用 model.load_state_dict 将 DTensor 状态字典加载到模型中。

from torch.distributed.tensor import distribute_tensor

# mmap=True reduces CPU memory usage
full_sd = torch.load(
    "checkpoints/model_state_dict.pt",
    mmap=True,
    weights_only=True,
    map_location='cpu',
)
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, full_tensor in full_sd.items():
    sharded_meta_param = meta_sharded_sd.get(param_name)
    sharded_tensor = distribute_tensor(
        full_tensor,
        sharded_meta_param.device_mesh,
        sharded_meta_param.placements,
    )
    sharded_sd[param_name] = nn.Parameter(sharded_tensor)
# `assign=True` since we cannot call `copy_` on meta tensor
model.load_state_dict(sharded_sd, assign=True)

保存状态字典model.state_dict() 返回一个 DTensor 状态字典。我们可以通过调用 full_tensor() 将 DTensor 转换为普通 torch.Tensor。内部调用了跨等级的 all-gather 来获取未分片参数的普通 torch.Tensor。在 0 号等级,full_param.cpu() 将张量逐一卸载到 CPU,以避免 GPU 内存峰值。

sharded_sd = model.state_dict()
cpu_state_dict = {}
for param_name, sharded_param in sharded_sd.items():
    full_param = sharded_param.full_tensor()
    if torch.distributed.get_rank() == 0:
        cpu_state_dict[param_name] = full_param.cpu()
    else:
        del full_param
torch.save(cpu_state_dict, "checkpoints/model_state_dict.pt")

优化器状态字典的工作原理类似(代码)。用户可以自定义上述 DTensor 脚本以支持第三方检查点。

如果不需要自定义,我们可以直接使用 DCP APIs 来支持单节点和多节点训练。

使用 DCP APIs 的状态字典

命令torchrun --nproc_per_node 2 train.py --dcp-api

  • 首次,为模型和优化器创建检查点。

  • 第二次,从之前的检查点加载以继续训练。

加载状态字典:我们可以通过 set_model_state_dict 将完整状态字典加载到 FSDP2 模型中。通过设置 broadcast_from_rank0=True,我们可以仅在 0 号等级加载完整状态字典,以避免 CPU 内存过载。DCP 会将张量分片并广播到其他等级。

from torch.distributed.checkpoint.state_dict import set_model_state_dict
set_model_state_dict(
    model=model,
    model_state_dict=full_sd,
    options=StateDictOptions(
        full_state_dict=True,
        broadcast_from_rank0=True,
    ),
)

保存状态字典get_model_state_dict 加上 full_state_dict=Truecpu_offload=True 可以完成 all-gather 张量并将其卸载到 CPU。这与 DTensor APIs 的工作方式类似。

from torch.distributed.checkpoint.state_dict import get_model_state_dict
model_state_dict = get_model_state_dict(
    model=model,
    options=StateDictOptions(
        full_state_dict=True,
        cpu_offload=True,
    )
)
torch.save(model_state_dict, "model_state_dict.pt")

有关使用 set_optimizer_state_dictget_optimizer_state_dict 加载和保存优化器状态字典的更多信息,请参阅 pytorch/examples

FSDP1 到 FSDP2 的迁移指南

让我们以 FSDP 的一个使用示例及其等效的 fully_shard 使用示例为例。我们将突出主要差异,并建议迁移步骤。

原始 FSDP() 使用方式

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with torch.device("meta"):
    model = Transformer()
policy = ModuleWrapPolicy({TransformerBlock})
model = FSDP(model, auto_wrap_policy=policy)
def param_init_fn(module: nn.Module) -> None: ...
model = FSDP(model, auto_wrap_policy=policy, param_init_fn=param_init_fn)

新的 fully_shard() 使用方式

with torch.device("meta"):
    model = Transformer()
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")


# Initialize the model after sharding
model.to_empty(device="cuda")
model.reset_parameters()

迁移步骤

  • 替换导入

  • 直接实现您的 ‘策略’ (对目标子层应用 fully_shard

  • fully_shard 包裹您的根模型,而不是使用 FSDP

  • 去掉 param_init_fn 并手动调用 model.reset_parameters()

  • 替换其他 FSDP1 的关键参数(详情见下)

分片策略

  • FULL_SHARD: reshard_after_forward=True

  • SHARD_GRAD_OP: reshard_after_forward=False

  • HYBRID_SHARD: reshard_after_forward=True,使用二维设备网格

  • _HYBRID_SHARD_ZERO2: reshard_after_forward=False,使用二维设备网格

CPU 卸载

  • CPUOffload.offload_params=False: offload_policy=None

  • CPUOffload.offload_params=True: offload_policy=CPUOffloadPolicy()

反向预取

  • BACKWARD_PRE: 总是使用

  • BACKWARD_POST: 不支持

混合精度

  • 因为 fully_shard 不分片缓冲区,所以省略了 buffer_dtype

  • fully_shard 的 cast_forward_inputs 映射到 FSDP1 中的 cast_forward_inputscast_root_forward_inputs

  • output_dtype 是一个新的配置用于 fully_shard

设备 ID: 从 device_mesh 的设备推断

sync_module_states=True/False: 已移至 DCP。用户可以使用 set_model_state_dictbroadcast_from_rank0=True 从 0 号等级广播状态字典

forward_prefetch: 可以手动控制预取

限制所有 all-gathers: 不再需要,因为 fully_shard 删除了 CPU 同步

使用原始参数: 始终使用原始参数(不再有平坦的参数)

no_sync(): set_requires_gradient_sync

忽略参数和忽略状态: ignored_params

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源