• Tutorials >
  • 基于全分片数据并行 (FSDP) 的高级模型训练
Shortcuts

基于全分片数据并行 (FSDP) 的高级模型训练

Created On: Oct 31, 2024 | Last Updated: Oct 31, 2024 | Last Verified: Nov 05, 2024

作者Hamid ShojanazeriLess WrightRohan VarmaYanli Zhao

What you will learn
  • PyTorch 的完全分片数据并行模块:一个用于在

数据并行工作者之间分片模块参数的封装。

Prerequisites
  • PyTorch 1.12 或更高版本

  • 阅读 FSDP API 的相关内容。

本教程介绍了 PyTorch 1.12 版本中完全分片数据并行(FSDP)的一些高级功能。要了解 FSDP 的基础知识,请参阅 FSDP 入门教程

在本教程中,我们以实际示例,通过使用 FSDP 对 HuggingFace (HF) T5 模型进行文本摘要微调。

示例使用了 Wikihow 数据集,为了简单起见,我们将在一台节点上展示训练过程,该节点是具有 8 个 A100 GPU 的 P4dn 实例。目前,我们已有多个关于多节点集群上大规模 FSDP 训练的博客文章 ((链接1), (链接2)) 和一篇 论文

FSDP 是一个面向生产环境的软件包,着重于易用性、性能和长期支持。FSDP 的主要优势之一是减少每个 GPU 的内存占用。这使得可以用较少的总内存训练更大的模型,并利用计算和通信的重叠来高效训练模型。减少的内存压力可以用于训练更大的模型或增加批量大小,从而可能提高整体训练吞吐量。您可以在 此处 阅读有关 PyTorch FSDP 的更多信息。

本教程中的 FSDP 功能

  • Transformer 自动封装策略

  • 混合精度

  • 在设备上初始化 FSDP 模型

  • 分片策略

  • 反向预取

  • 通过流式传输到 CPU 保存模型检查点

FSDP 的工作原理回顾

从整体来看,FDSP 的工作流程如下:

在构造函数中

  • 分片模型参数,每个排名仅保留自己的分片。

在前向传递中

  • 运行 all_gather 从所有排名中收集所有分片以恢复完整参数,用于该 FSDP 单元并运行前向计算

  • 丢弃其刚刚收集的非拥有分片以释放内存

在反向传递中

  • 运行 all_gather 从所有排名中收集所有分片以恢复该 FSDP 单元中的完整参数并运行反向计算

  • 丢弃非拥有参数以释放内存。

  • 运行 reduce_scatter 同步梯度。

HF T5 微调

HF T5 预训练模型有四种不同的大小,从带有 6000 万参数的小型模型到带有 110 亿参数的 XXL 模型。在本教程中,我们利用 FSDP 对 WikiHow 数据集中的 T5 3B 模型进行文本摘要微调。本教程的主要重点是突出 FSDP 中可用的不同功能,这些功能对训练规模超过 3B 参数的大型模型有帮助。此外,我们还涵盖了基于 Transformer 模型的特定功能。本教程的代码可在 Pytorch 示例 中找到。

设置

1.1 安装最新的 PyTorch

pip3 install torch torchvision torchaudio

1.2 数据集设置

请创建一个 data 文件夹,从 wikihowAll.csvwikihowSep.cs 下载 WikiHow 数据集,并将其放置在 data 文件夹中。我们将使用 summarization_dataset 中的 wikihow 数据集。

接下来,我们将以下代码片段添加到 Python 脚本“T5_training.py”中。

备注

本教程的完整源码可在 PyTorch 示例 中找到。

1.3 导入必要的软件包:

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
import functools
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from transformers.models.t5.modeling_t5 import T5Block

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
 checkpoint_wrapper,
 CheckpointImpl,
 apply_activation_checkpointing_wrapper)

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from functools import partial
from torch.utils.data import DataLoader
from pathlib import Path
from summarization_dataset import *
from transformers.models.t5.modeling_t5 import T5Block
from typing import Type
import time
import tqdm
from datetime import datetime

1.4 分布式训练设置。这里我们使用两个辅助函数分别用于初始化分布式训练的进程,以及在训练完成后进行清理。在本教程中,我们将使用 torch elatic,通过 torchrun 来自动设置 worker 的 RANKWORLD_SIZE

def setup():
    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

2.1 设置 HuggingFace T5 模型:

def setup_model(model_name):
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer =  T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer

同时我们还添加了一些用于日期和格式化内存指标的辅助函数。

def get_date_of_run():
    """create date and time for file save uniqueness
    example: 2022-05-07-08:31:12_PM'
    """
    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
    print(f"--> current date and time of run = {date_of_run}")
    return date_of_run

def format_metrics_to_gb(item):
    """quick function to format numbers to gigabyte and round to 4 digit precision"""
    metric_num = item / g_gigabyte
    metric_num = round(metric_num, ndigits=4)
    return metric_num

2.2 定义一个训练函数:

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(2).to(local_rank)

    if sampler:
        sampler.set_epoch(epoch)
    if rank==0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )
    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"] )
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        fsdp_loss[0] += loss.item()
        fsdp_loss[1] += len(batch)
        if rank==0:
            inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_accuracy = fsdp_loss[0] / fsdp_loss[1]


    if rank == 0:
        inner_pbar.close()
        print(
                f"Train Epoch: \t{epoch}, Loss: \t{train_accuracy:.4f}"
            )
    return train_accuracy

2.3 定义一个验证函数:

def validation(model, rank, world_size, val_loader):
    model.eval()
    correct = 0
    local_rank = int(os.environ['LOCAL_RANK'])
    fsdp_loss = torch.zeros(3).to(local_rank)
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )
    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(input_ids=batch["source_ids"],attention_mask=batch["source_mask"],labels=batch["target_ids"])
            fsdp_loss[0] += output["loss"].item()  # sum up batch loss
            fsdp_loss[1] += len(batch)

            if rank==0:
                inner_pbar.update(1)

    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]
    if rank == 0:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss

2.4 定义一个分布式训练函数,包装 FSDP 中的模型:

def fsdp_main(args):

    model, tokenizer = setup_model("t5-base")

    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])


    dataset = load_dataset('wikihow', 'all', data_dir='data/')
    print(dataset.keys())
    print("Size of train dataset: ", dataset['train'].shape)
    print("Size of Validation dataset: ", dataset['validation'].shape)


    #wikihow(tokenizer, type_path, num_samples, input_length, output_length, print_text=False)
    train_dataset = wikihow(tokenizer, 'train', 1500, 512, 150, False)
    val_dataset = wikihow(tokenizer, 'validation', 300, 512, 150, False)

    sampler1 = DistributedSampler(train_dataset, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(val_dataset, rank=rank, num_replicas=world_size)

    setup()


    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
    sharding_strategy: ShardingStrategy = ShardingStrategy.SHARD_GRAD_OP #for Zero2 and FULL_SHARD for Zero3
    torch.cuda.set_device(local_rank)


    #init_start_event = torch.cuda.Event(enable_timing=True)
    #init_end_event = torch.cuda.Event(enable_timing=True)

    #init_start_event.record()

    bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
    )

    if bf16_ready:
        mp_policy = bfSixteen
    else:
        mp_policy = None # defaults to fp32

    # model is on CPU before input to FSDP
    model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=mp_policy,
        #sharding_strategy=sharding_strategy,
        device_id=torch.cuda.current_device())

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    best_val_loss = float("inf")
    curr_val_loss = float("inf")
    file_save_name = "T5-model-"

    if rank == 0:
        time_of_run = get_date_of_run()
        dur = []
        train_acc_tracking = []
        val_acc_tracking = []
        training_start_time = time.time()

    if rank == 0 and args.track_memory:
        mem_alloc_tracker = []
        mem_reserved_tracker = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_accuracy = train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        if args.run_validation:
            curr_val_loss = validation(model, rank, world_size, val_loader)
        scheduler.step()

        if rank == 0:

            print(f"--> epoch {epoch} completed...entering save and stats zone")

            dur.append(time.time() - t0)
            train_acc_tracking.append(train_accuracy.item())

            if args.run_validation:
                val_acc_tracking.append(curr_val_loss.item())

            if args.track_memory:
                mem_alloc_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_allocated())
                )
                mem_reserved_tracker.append(
                    format_metrics_to_gb(torch.cuda.memory_reserved())
                )
            print(f"completed save and stats zone...")

        if args.save_model and curr_val_loss < best_val_loss:

            # save
            if rank == 0:
                print(f"--> entering save model state")

            save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, save_policy
            ):
                cpu_state = model.state_dict()
            #print(f"saving process: rank {rank}  done w state_dict")


            if rank == 0:
                print(f"--> saving model ...")
                currEpoch = (
                    "-" + str(epoch) + "-" + str(round(curr_val_loss.item(), 4)) + ".pt"
                )
                print(f"--> attempting to save model prefix {currEpoch}")
                save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
                print(f"--> saving as model name {save_name}")

                torch.save(cpu_state, save_name)

        if curr_val_loss < best_val_loss:

            best_val_loss = curr_val_loss
            if rank==0:
                print(f"-->>>> New Val Loss Record: {best_val_loss}")

    dist.barrier()
    cleanup()

2.5 解析参数并设置主函数:

if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch T5 FSDP Example')
    parser.add_argument('--batch-size', type=int, default=4, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=4, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=2, metavar='N',
                        help='number of epochs to train (default: 3)')
    parser.add_argument('--lr', type=float, default=.002, metavar='LR',
                        help='learning rate (default: .002)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--track_memory', action='store_false', default=True,
                        help='track the gpu memory')
    parser.add_argument('--run_validation', action='store_false', default=True,
                        help='running the validation')
    parser.add_argument('--save-model', action='store_false', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    fsdp_main(args)

使用 torchrun 运行训练:

torchrun --nnodes 1 --nproc_per_node 4  T5_training.py

Transformer 封装策略

正如 先前教程 所讨论的,auto_wrap_policy 是 FSDP 的一项功能,它使得自动分片给定模型并将模型、优化器和梯度分片放入不同的 FSDP 单元变得容易。

对于一些架构,如 Transformer 编码器-解码器,模型的某些部分(如嵌入表)同时与编码器和解码器共享。在这种情况下,我们需要将嵌入表放置在外部 FSDP 单元中,以便从编码器和解码器访问它。此外,通过注册 Transformer 的层类,可以使分片计划更具通信效率。在 PyTorch 1.12 中,FSDP 增加了对此功能的支持,现在我们有了 Transformer 的封装策略。

可以按如下方式创建,其中 T5Block 表示 T5 Transformer 层类(包含 MHSA 和 FFN)。

t5_auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            T5Block,
        },
    )
torch.cuda.set_device(local_rank)


model = FSDP(model,
    auto_wrap_policy=t5_auto_wrap_policy)

要查看包装的模型,您可以轻松打印模型并可视化检查分片和 FSDP 单元。

混合精度

FSDP 支持灵活的混合精度训练,允许使用任意减少精度类型(如 fp16 或 bfloat16)。目前,BFloat16 仅适用于 Ampere GPU,因此在使用前需确认其原生支持。例如,在 V100 GPU 上,尽管 BFloat16 仍可运行,但由于非原生运行,可能导致显著的性能下降。

要检查 BFloat16 是否原生支持,您可使用以下代码:

bf16_ready = (
    torch.version.cuda
    and torch.cuda.is_bf16_supported()
    and LooseVersion(torch.version.cuda) >= "11.0"
    and dist.is_nccl_available()
    and nccl.version() >= (2, 10)
)

混合精度的一个优势是可以对参数、梯度和缓冲区的不同精度级别进行细粒度控制,具体如下:

fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
)

fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    # Gradient communication precision.
    reduce_dtype=torch.float32,
    # Buffer precision.
    buffer_dtype=torch.float32,
)

注意,如果未指定某种类型(参数、梯度归约、缓冲区),它们将不会进行精度转换。

此灵活性允许用户进行精细控制,例如仅设置梯度通信以较低精度运行,而所有参数/缓冲区计算以全精度完成。这在节点间通信是主要瓶颈而参数/缓冲区必须全精度以避免准确性问题的情况下可能很有用。这可以通过以下策略完成:

grad_bf16 = MixedPrecision(reduce_dtype=torch.bfloat16)

在步骤 2.4 中,我们仅仅将相关的混合精度策略添加到 FSDP 包装器中:

model = FSDP(model,
       auto_wrap_policy=t5_auto_wrap_policy,
       mixed_precision=bfSixteen)

在我们的实验中,使用 BFloat16 进行了训练,观察到速度提高最多达 4 倍,并且内存减少大约 30%,这些内存可用于增加批量大小。

在设备上初始化 FSDP 模型

在 1.12 中,FSDP 支持 device_id 参数,用于在 device_id 指定的设备上初始化输入的 CPU 模块。当整个模型无法放入单个 GPU,但能装入主机的 CPU 内存时,此功能很有用。当指定 device_id 时,FSDP 将逐个 FSDP 单元将模型移至指定设备,避免 GPU 内存不足问题,同时初始化速度比基于 CPU 的方式快几倍:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device())

分片策略

默认情况下,FSDP 分片策略设置为完全分片模型参数、梯度和优化器状态(也称为 Zero3 分片)。如果您有兴趣使用 Zero2 分片策略(仅分片优化器状态和梯度),FSDP 可以通过将分片策略从”ShardingStrategy.FULL_SHARD”改为”ShardingStrategy.SHARD_GRAD_OP”传递给 FSDP 初始化来支持该功能,如下所示:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # ZERO2)

这将减少 FSDP 的通信开销,在此情况下,它会在前向和反向传递中保持完整参数。

这可以在反向传递期间节省一次 all_gather,因此通信更少,但会以较高的内存占用为代价。注意,在反向传递的末尾会释放所有模型参数,而在下一个前向传递中会再次发生 all_gather。

反向预取

反向预取设置控制下一个 FSDP 单元的参数何时开始请求。通过将其设置为 BACKWARD_PRE,可以在当前单元的计算开始前提前请求并接收下一个 FSDP 单元的参数。这种通信和梯度计算的重叠可以略微提高训练速度,但以稍高的内存消耗为代价。在 2.4 中可以通过以下方式在 FSDP 包装器中利用它:

torch.cuda.set_device(local_rank)

 model = FSDP(model,
        auto_wrap_policy=t5_auto_wrap_policy,
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
        backward_prefetch = BackwardPrefetch.BACKWARD_PRE)

backward_prefetch 有两种模式,BACKWARD_PREBACKWARD_POSTBACKWARD_POST 表示在完成当前 FSDP 单元处理之前不会请求下一个 FSDP 单元的参数,从而将内存开销最小化。在某些情况下,使用 BACKWARD_PRE 可以将模型训练速度提高 2% 到 10%,对于更大的模型甚至能实现更高的速度提升。

通过流式传输到 Rank0 CPU 保存模型检查点

要使用 FULL_STATE_DICT 进行保存(以与本地模型相同的方式保存模型),PyTorch 1.12 提供了一些实用工具来支持保存更大的模型。

首先,可以指定 FullStateDictConfig,以便将 state_dict 仅填充到 0 号排名并卸载到 CPU。

使用此配置时,FSDP将收集模型参数,仅在0号任务中将其逐一卸载到CPU。当最终保存state_dict时,它将仅在0号任务中填充,并包含CPU张量。这可以避免对于比单个GPU内存更大的模型可能出现的内存不足问题,并允许用户对模型大小接近用户机器上可用的CPU RAM的模型进行检查点。

可以如下运行此功能:

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state = model.state_dict()
if rank == 0:
 save_name = file_save_name + "-" + time_of_run + "-" + currEpoch
 torch.save(cpu_state, save_name)

摘要

在本教程中,我们介绍了PyTorch 1.12中FSDP的许多新功能,并使用HF T5作为运行示例。对于Transformer模型,使用合适的包装策略,结合混合精度和后向预取,可以加速训练。此外,诸如在设备上初始化模型以及通过流式传输到CPU来保存检查点等功能可以帮助避免处理大模型时出现内存不足错误。

我们正在积极工作,为下一版本的FSDP添加新功能。如果您有反馈、功能请求、问题或在使用FSDP时遇到问题,请随时通过在`PyTorch GitHub存储库 <https://github.com/pytorch/pytorch>`__中打开问题与我们联系。

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源