• Tutorials >
  • 使用 Join 上下文管理器进行具有不均衡输入的分布式训练
Shortcuts

使用 Join 上下文管理器进行具有不均衡输入的分布式训练

Created On: Aug 04, 2021 | Last Updated: Jan 09, 2023 | Last Verified: Nov 05, 2024

作者: Andrew Gu

备注

|编辑|github 中查看和编辑此教程。

备注

Join 在 PyTorch 1.10 中作为原型功能引入。此 API 可能会有变更。

在本教程中,您将看到:

  • Join 上下文管理器的概述。

  • 一个使用 DistributedDataParallel 的上下文管理器示例。

  • 一个使用 DistributedDataParallelZeroRedundancyOptimizer 的上下文管理器示例。

  • 一个传递关键字参数给上下文管理器的示例。

  • 深入探讨 Join 上下文管理器的工作原理。

  • 一个示例展示如何使一个玩具类兼容上下文管理器。

什么是 Join

使用分布式数据并行入门 - 基础应用场景 中,您看到了使用 DistributedDataParallel 进行数据并行训练的一般框架。这会隐式地在每次反向传递中调度所有归约操作以同步各个设备的梯度。这类`集合通信 <https://pytorch.org/docs/stable/distributed.html>`__需要该流程组中的所有设备参与,因此如果某个设备拥有较少的输入,其他设备就会停滞或出错(取决于后端)。更一般而言,对于任何执行每轮同步集合通信的类,这个问题都会持续存在。

Join 是一个上下文管理器,用于围绕每设备的训练循环,以便在输入不均衡情况下进行训练。上下文管理器允许提前耗尽其输入的设备(即 提前加入)影子执行尚未加入的设备进行的集合通信。影子操作的方式由钩子指定。

使用 JoinDistributedDataParallel

PyTorch 的 DistributedDataParallel 可以直接与 Join 上下文管理器一起使用。以下是一个使用示例:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

这将生成如下输出(其中,来自设备 0 和设备 1 的 print() 顺序可能是任意的):

Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

备注

DistributedDataParallel 在引入这个通用 Join 上下文管理器之前已经提供了自己的 join() 上下文管理器。在上面的示例中,使用 with Join([model]): 等同于使用 with model.join():。现有的 DistributedDataParallel.join() 的一个限制是它不允许多个参与类,例如同时使用 DistributedDataParallelZeroRedundancyOptimizer

使用 Join,结合 DistributedDataParallelZeroRedundancyOptimizer

Join 上下文管理器不仅可以与单个类一起工作,还可以与多个类一起工作。PyTorch 的 ZeroRedundancyOptimizer 也与上下文管理器兼容。因此,在这里,我们研究如何修改以前的示例以同时使用 DistributedDataParallelZeroRedundancyOptimizer

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

这将生成与之前相同的输出。显著的变化是将 ZeroRedundancyOptimizer 实例也传递给 Join()

传递关键字参数

类可以提供关键字参数,在运行时修改其在上下文管理器中的行为。例如,DistributedDataParallel 提供了一个参数 divide_by_initial_world_size,它决定梯度是除以初始设备数还是除以有效设备数(即未加入的设备数)。这样的关键字参数可以直接传递到上下文管理器中。

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

警告

传递到上下文管理器中的关键字参数在所有参与类之间共享。这不应成为限制,因为我们不预期在同一参数的不同设置上需要不同的 Joinable 的情况。然而,这是需要注意的一点。

Join 的工作原理

现在我们已经看到了一些如何使用 Join 上下文管理器的初步示例,让我们更深入地探讨它的工作原理。这将提供对其全面功能的更深入了解,并准备好使您自己的自定义类兼容。接下来,我们将介绍 Join 类以及支持类 JoinableJoinHook

Joinable

首先,与 Join 上下文管理器兼容的类必须继承抽象基类 Joinable。特别是,一个 Joinable 必须实现:

  • join_hook(self, **kwargs) -> JoinHook

这会返回``Joinable``的``JoinHook``实例,以确定参与训练的进程如何跟踪``Joinable``在每次迭代中进行的广播通信。

  • join_device(self) -> torch.device

这会返回一个设备供``Join``上下文管理器用于执行广播通信,例如``torch.device(“cuda:0”)``或``torch.device(“cpu”)``。

  • join_process_group(self) -> ProcessGroup

这会返回一个进程组供``Join``上下文管理器用于执行广播通信。

特别是,``join_device``和``join_process_group``是必需属性,以确保上下文管理器能够在已加入和未加入的进程之间计划广播通信。一种用途是使用全归约在每次迭代中计算未加入进程的数量。另一种用途是实现``throw_on_early_termination=True``所需的机制,稍后我们会对此进行解释。

``DistributedDataParallel``和``ZeroRedundancyOptimizer``已经继承了``Joinable``并实现了上述方法,这就是为什么我们可以直接在之前的示例中使用它们。

``Joinable``类应确保调用``Joinable``构造函数,因为它会初始化一个``JoinConfig``实例,用于上下文管理器内部确保正确性。这将作为字段``_join_config``保存于每个``Joinable``中。

JoinHook

接下来,让我们详细说明``JoinHook``类。一个``JoinHook``为上下文管理器提供了两个入口点:

  • main_hook(self) -> None

在仍有尚未加入的进程存在的情况下,此钩子会被每个已加入的进程重复调用,目的是跟踪``Joinable``在每次训练迭代中执行的广播通信(例如一次前向传播、后向传播和优化步骤)。

  • post_hook(self, is_last_joiner: bool) -> None

此钩子会在所有进程已加入之后被调用。它会传递一个额外的``bool``参数``is_last_joiner``,指示该进程是否是最后一个加入的进程之一。此参数可能对同步有帮助。

为了给出这些钩子的具体示例,所提供的``ZeroRedundancyOptimizer``主钩子正常执行优化步骤,因为已加入的进程仍然负责更新和同步其参数的分片,而所提供的``DistributedDataParallel``后钩子从最后加入的进程之一广播最终更新的模型,以确保所有进程中的模型一致。

Join

最后,让我们检查这些如何适应到``Join``类本身。

  • __init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)

如我们在之前的示例中所见,构造函数接受参与训练循环的``Joinable``列表。这些应该是每次迭代中执行广播通信的类。

enable``是一个``bool,如果知道不会有不平衡输入,可以将其设置为``False``,此时上下文管理器变得无效,类似于``contextlib.nullcontext()``。这也可能在参与的``Joinable``中禁用与加入相关的计算。

throw_on_early_termination``是一个``bool,可以设置为``True``以使每个进程在检测到输入不平衡时立刻抛出异常。这对于不符合上下文管理器要求的情况非常有用,通常在有多个类的广播通信可能任意交错时,例如使用包含``SyncBatchNorm``层的模型与``DistributedDataParallel``一起使用。在这种情况下,此参数应设置为``True``,以便应用逻辑可以捕获异常并决定如何继续。

  • 核心逻辑发生在``__exit__()``方法中,该方法在仍存在未加入的进程时循环调用每个``Joinable``的主钩子,并且在所有进程已加入之后调用它们的后钩子。无论是主钩子还是后钩子,都会按照传递的``Joinable``列表中的顺序进行迭代。

  • 上下文管理器需要来自未加入进程的心跳。因此,每个``Joinable``类应在每次迭代的广播通信之前调用``Join.notify_join_context()``。上下文管理器会确保只有传递进来的第一个``Joinable``实际发送心跳。

警告

如上所述关于``throw_on_early_termination``,``Join``上下文管理器与某些类组合不兼容。``Joinable``的``JoinHook``必须是可序列化的,因为每个钩子都需要在执行下一个钩子之前完全执行完毕。换句话说,两个钩子不能重叠。此外,目前主钩子和后钩子都按照相同的确定性顺序进行迭代。如果这似乎是一个主要限制,我们可能会修改API以允许自定义排序。

使一个简单类兼容``Join``

由于上一节介绍了几个概念,让我们通过一个简单示例来实践它们。这里,我们将实现一个类,该类统计所有进程在首次加入之前看到的输入数量。这应该提供如何使自己的类兼容``Join``上下文管理器的基本思路。

具体来说,以下代码会让每个进程打印出:(1)在加入之前所有进程看到的输入数量,以及(2)所有进程看到的输入总数。

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

由于进程0看到5个输入,进程1看到6个输入,这会产生如下输出:

10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

一些需要注意的关键点:

  • 一个``Counter``实例每次迭代执行一次全归约,因此主钩子也执行一次全归约以跟踪它。

  • Counter``类在其``__call__()``方法的开始处调用``Join.notify_join_context(),因为这是它进行每次迭代广播通信(即其全归约)之前的地方。

  • ``is_last_joiner``参数用于在后钩子中确定广播源。

  • 我们将``sync_max_count``关键字参数传递给上下文管理器,然后转发到``Counter``的加入钩子。

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源