使用张量并行 (TP) 进行大规模 Transformer 模型训练¶
Created On: Apr 19, 2024 | Last Updated: Aug 19, 2024 | Last Verified: Nov 05, 2024
作者: Wanchao Liang, Tianyu Liu
备注
在 github 中查看和编辑本教程。
本教程演示了如何使用张量并行和完全分片的数据并行在数百到数千个 GPU 上训练大型 Transformer 类模型。
前提条件:
安装了 PyTorch 2.3.0 或更高版本,并启用了 CUDA/Linux
张量并行如何工作?¶
张量并行 (TP) 最初是在 Megatron-LM 论文中提出的,它是一种高效的模型并行技术,用于训练大规模 Transformer 模型。本教程中提到的 序列并行 (SP) 是张量并行的一种变体,可以在 nn.LayerNorm
或 RMSNorm
的序列维度上进行分片,从而进一步节省训练期间的激活内存。随着模型变得更大,激活内存成为瓶颈,因此在张量并行训练中通常将序列并行应用于 LayerNorm
或 RMSNorm
层。
从高层次来看,PyTorch 张量并行的工作原理如下:
分片初始化
确定对每一层应用哪种
ParallelStyle
,并通过调用parallelize_module
对初始化的模块进行分片。并行化的模块将其模型参数交换为 DTensors,DTensor 将负责使用分片计算运行并行化模块。
运行时前向/后向
根据用户为每个
ParallelStyle
指定的输入/输出 DTensor 布局,运行适当的通信操作以转换输入/输出的 DTensor 布局(例如allreduce
、allgather
和reduce_scatter
)。对并行化的层运行分片计算以节省计算和内存(例如
nn.Linear
、nn.Embedding
)。
何时以及为什么应用张量并行¶
PyTorch 的完全分片的数据并行 (FSDP) 已经能够将模型训练扩展到特定数量的 GPU。然而,当进一步在模型规模和 GPU 数量方面扩展模型训练时,会出现许多额外的挑战,这可能需要结合张量并行与 FSDP:
当世界大小(即 GPU 数量)变得过大(超过 128/256 个 GPU)时,FSDP 的集合通信(例如
allgather
)开始被环形延迟所支配。通过在 FSDP 上实现 TP/SP,可以将 FSDP 的世界大小减少 8 倍,通过仅在主机间应用 FSDP,从而相应减少延迟成本。数据并行性达到限制,即由于收敛性和 GPU 内存限制,无法将全局批大小提高到 GPU 数量之上,张量/序列并行是唯一已知的方法,可以“大致估算”全局批大小并继续通过更多 GPU 扩展。这意味着模型规模和 GPU 数量都可以继续扩展。
对于某些类型的模型,当本地批量大小变得较小时,TP/SP 可以生成更优化浮点操作(FLOPS)的矩阵乘法形状。
那么,在预训练时,达到这些限制有多容易?截至目前,即使使用数千个 GPU,预训练一个有数十亿或万亿个标记的大型语言模型 (LLM) 也可能需要数月时间。
在大规模训练 LLM 时,总是会达到限制 1。例如,Llama 2 70B 在 2000 个 GPU 上训练 35 天,需要在 2000 规模上使用多维并行。
当 Transformer 模型变得更大(例如 Llama2 70B)时,也会迅速达到限制 2。即使本地
batch_size=1
,也不能单独使用 FSDP,原因是内存和收敛性限制。例如,Llama 2 的全局批量大小为 1K,因此仅数据并行性无法用于 2K GPU。
如何应用张量并行¶
PyTorch 张量并行 API 提供了一组模块级原语(ParallelStyle
)来配置模型每个单层的分片,包括:
ColwiseParallel
和RowwiseParallel
:以列或行方式分片nn.Linear
和nn.Embedding
。SequenceParallel
:对nn.LayerNorm
、nn.Dropout
、RMSNormPython
等进行分片计算。PrepareModuleInput
和PrepareModuleOutput
:通过适当的通信操作配置模块输入/输出的分片布局。
为了演示如何使用 PyTorch 原生张量并行 API,让我们来看看一个常见的 Transformer 模型。本教程中,我们使用最新的 Llama2 模型 作为参考 Transformer 模型实现,因为它也被社区广泛使用。
由于张量并行对单个张量在一组设备上进行分片,我们需要首先设置分布式环境(例如 NCCL 通信器)。张量并行是一种类似于 PyTorch DDP/FSDP 的单程序多数据 (SPMD) 分片算法,底层利用 PyTorch DTensor 执行分片。它还利用 DeviceMesh 抽象(底层管理 ProcessGroups)进行设备管理和分片。要了解如何利用 DeviceMesh 设置多维并行,请参考 本教程。张量并行通常在每个主机内工作,因此我们首先初始化一个连接主机内 8 个 GPU 的 DeviceMesh。
from torch.distributed.device_mesh import init_device_mesh
tp_mesh = init_device_mesh("cuda", (8,))
现在我们已经初始化了 DeviceMesh,让我们详细查看 Llama 2 模型架构并了解如何执行张量并行分片。这里我们重点关注核心 TransformerBlock
,其中 Transformer 模型堆叠多个相同的 TransformerBlock
来增强模型规模。
核心 TransformerBlock
包括一个 Attention
层和一个 FeedForward
层。我们先来看简化版 FeedForward
层。对于 FeedForward
层,它由三个线性层组成,其中执行了 SwiGLU 风格的 MLP,看其前向函数:
# forward in the FeedForward layer
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
它并行执行 w1
和 w3
的矩阵乘法,然后结合 w1/w3 的线性投影结果进行 w2
的矩阵乘法。这意味着我们可以利用张量并行论文中的思路以列分片方式分片 w1/w3 线性层,并以行分片方式分片 w2
线性层,以便在所有三个层结束时仅发生一次 allreduce
通信。通过 PyTorch 原生张量并行,我们可以像下面这样为 FeedForward
层简单地创建一个 parallelize_plan
:
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"feed_foward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这就是我们如何使用 PyTorch 张量并行 API 配置 FeedForward
层分片的方式。请注意,用户只需指定如何分片单独的层,通信(例如 allreduce
)将在底层自动进行。
接下来是 Attention
层。它包括将输入投射到 q
/ k
/ v
的 wq
、wk
、wv
线性层,然后使用 wo
线性层执行注意力和输出投影。这里张量并行意图对 q/k/v 投影进行列分片,对 wo
线性投影进行行分片。因此,我们可以将 Attention 计划添加到我们刚刚起草的 tp_plan
中:
layer_tp_plan = {
# by default ColwiseParallel input layouts is replicated
# and RowwiseParallel output layouts is replicated
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
}
这几乎就是我们需要应用张量并行的 layer_tp_plan
,以用于 TransformerBlock
。然而,我们需要注意的是,在对线性层进行列分片时,线性层的输出会在张量的最后一个维度上分片,而行分片线性层直接接收在最后一个维度上分片的输入。如果在列分片线性层和行分片线性层之间还有其他张量操作(例如视图操作),我们需要调整相关的与形状有关的操作到分片形状。
对于 Llama 模型,注意力层中有一些与形状相关的视图操作。特别是对于 wq
、wk
和 wv
线性层的列分片并行,激活张量在 num_heads
维度上分片,因此我们需要调整 num_heads
到本地 num_heads
。
最后,我们需要调用 parallelize_module
API 使每个 TransformerBlock
的计划生效。底层会将 Attention
和 FeedForward
层中的模型参数分发到 DTensors,并为模型输入和输出(分别在每个模块之前和之后)注册通信钩子(如有必要):
for layer_id, transformer_block in enumerate(model.layers):
layer_tp_plan = {...} # i.e. the plan we just generated
# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
现在我们已经详细制定了每个``TransformerBlock``的分片计划,通常在第一层会有一个``nn.Embedding``,最后会有一个``nn.Linear``投影层。在这些层中,用户可以选择对第一个``nn.Embedding``进行行切分或列切分,并通过指定适当的输入和输出布局,对最后的``nn.Linear``投影层进行列切分。以下是一个示例:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
),
"output": ColwiseParallel(
output_layouts=Replicate(),
),
}
)
备注
如果要分割的模型太大,无法放入 CPU 内存,可以使用``meta``设备初始化(例如,首先在 meta 设备上初始化模型,然后分割层,再实现模型),或者在 Transformer 模型初始化期间逐层并行化``TransformerBlock``层。
将序列并行应用于``LayerNorm/RMSNorm``层¶
序列并行基于上面介绍的张量并行构建。与基本的张量并行相比,它不仅分割``Attention``模块和``FeedForward``模块中的张量,还将它们的模块输入和输出(即在前向传递中的激活和在反向传递中的梯度)保持分割,而不是复制。
在典型的``TransformerBlock``中,前向函数结合了规范层(LayerNorm``或``RMSNorm
)、注意力层、前馈层和残差连接。例如:
# forward in a TransformerBlock
def forward(self, x):
h = x + self.attention(self.attention_norm(x))
out = h + self.feed_forward(self.ffn_norm(h))
return out
在大多数使用场景中,激活(以及梯度)的形状通常是``[批量大小,序列长度,隐藏维度]``,位于``Attention``和``FeedForward``模块之外。在 DTensor 的语言中,序列并行使用``Shard(1)``布局对模块进行前向和后向激活计算。以下代码展示了如何将序列并行应用于``TransformerBlock``中的规范层:
首先我们导入序列并行所需的依赖项:
from torch.distributed.tensor.parallel import (
PrepareModuleInput,
SequenceParallel,
)
接下来我们调整``layer_tp_plan``以在``RMSNorm``层上启用序列并行:
layer_tp_plan = {
# Now the input and output of SequenceParallel has Shard(1) layouts,
# to represent the input/output tensors sharded on the sequence dimension
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
}
可以看到,我们现在使用``PrepareModuleInput``将 Attention 和 FeedForward 层的模块输入布局从``Shard(1)``修改为``Replicate()``,并将它们的输出布局标记为``Shard(1)``。就像张量并行性一样,我们只需要指定输入和输出的张量分片布局,层之间的通信将自动发生。
请注意,使用序列并行时,我们假设``TransformerBlock``的输入和输出始终在序列维度上分片,以便可以无缝连接多个``TransformerBlocks``。这可以通过显式指定起始``nn.Embedding``层的输出和最终``nn.Linear``投影层的输入为``Shard(1)``来实现。
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate()
),
}
)
应用损失并行¶
损失并行是一种相关技术,用于在计算损失函数时节省内存和通信,因为模型的输出通常非常大。在损失并行中,当模型输出在(通常是巨大的)词汇维度上分片时,可以高效地计算交叉熵损失,而无需将所有模型输出收集到每个单独的 GPU。这不仅显著减少了内存消耗,还通过减少通信开销和在并行中进行分片计算来提高训练速度。下图简要说明了损失并行如何通过分片计算避免将所有模型输出收集到每个 GPU。

图 2. 使用损失并行在单个 GPU 上进行交叉熵损失的前向计算。蓝色表示分片张量;绿色表示复制张量;黄色表示具有部分值的张量(待全局缩减)。黑色箭头是本地计算;红色箭头是 GPU 之间的功能集合。¶
在 PyTorch 的张量并行 API 中,可以通过上下文管理器``loss_parallel``启用损失并行,使用此管理器,一可以直接使用``torch.nn.functional.cross_entropy``或``torch.nn.CrossEntropyLoss``,而无需修改代码的其他部分。
要应用损失并行,模型预测通常的形状``[批量大小,序列长度,词汇表大小]``应该在词汇表维度上分片。这可以通过标记最后线性投影层输出的布局轻松完成:
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
# use DTensor as the output
use_local_output=False,
),
},
)
在上面的代码中,我们还将序列并行应用于输出之前的规范层。我们应用``use_local_output=False``以使输出保持为 DTensor,能够与``loss_parallel``上下文管理器一起工作。在此之后,可以简单地调用交叉熵损失函数,如下所示。请注意,反向计算也需要在上下文内进行。
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel
pred = model(input_ids)
with loss_parallel():
# assuming pred and labels are of the shape [batch, seq, vocab]
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
loss.backward()
将张量并行与完全分片数据并行结合使用¶
现在我们已经展示了如何将张量/序列并行应用于模型,让我们看一下张量并行和完全分片数据并行如何协同工作。由于张量并行性会导致阻止计算的通信,我们需要确保它在一个快速通信通道内运行,例如 NVLink。在实践中,我们通常在每个主机内应用张量并行,并在主机之间应用完全分片数据并行。

图 3. FSDP 和 TP 在单独的设备维度上工作,FSDP 通信发生在主机之间,TP 通信发生在主机内部。¶
这种二维并行模式可以通过二维设备网格轻松表达,我们只需将每个“子”设备网格传递给各自的并行 API:
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
tp_mesh = mesh_2d["tp"] # a submesh that connects intra-host devices
dp_mesh = mesh_2d["dp"] # a submesh that connects inter-host devices
model = Model(...)
tp_plan = {...}
# apply Tensor Parallel intra-host on tp_mesh
model_tp = parallelize_module(model, tp_mesh, tp_plan)
# apply FSDP inter-host on dp_mesh
model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
这使我们能够轻松地在每个主机内(主机内部)应用张量并行,并在主机间(主机之间)应用 FSDP,对 Llama 模型无代码更改。张量(模型)并行和数据并行技术结合在一起提供了继续增大模型规模的能力,并使用大量 GPU 进行高效训练。
结论¶
本教程演示了如何使用张量并行结合完全分片数据并行在数百到数千个 GPU 上训练大型 Transformer 类模型。它解释了如何将张量并行应用于模型的不同部分,并且对模型本身没有代码更改。张量并行是一种高效的大规模训练模型并行技术。
要查看本教程中解释的完整端到端代码示例,请参阅 PyTorch/examples 存储库中的`张量并行示例 <https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/fsdp_tp_example.py>`__。