• Tutorials >
  • 使用``torchrun``进行故障容错分布式训练
Shortcuts

简介 || 什么是DDP || 单节点多GPU训练 || 故障容错 || 多节点训练 || minGPT训练

使用``torchrun``进行故障容错分布式训练

Created On: Sep 27, 2022 | Last Updated: Nov 12, 2024 | Last Verified: Nov 05, 2024

作者: Suraj Subramanian

What you will learn
  • 使用``torchrun``启动多GPU训练作业

  • 保存和加载训练作业的快照

  • 为优雅重启结构化训练脚本

GitHub 上查看此教程使用的代码

Prerequisites
  • DDP的高级 概述

  • 熟悉 DDP代码

  • 具有多GPU的机器(本教程使用AWS p3.8xlarge实例)

  • 安装CUDA版本的PyTorch 安装

根据以下视频或在 YouTube 上进行学习。

在分布式训练中,单个进程的故障可能会破坏整个训练作业。由于故障的可能性在这里更高,使训练脚本鲁棒性尤为重要。此外,您可能更喜欢使训练作业具有*弹性*,例如计算资源可以在作业期间动态加入和退出。

PyTorch提供了一种名为``torchrun``的工具,支持容错和弹性训练。当发生故障时,``torchrun``记录错误,并尝试从上次保存的训练作业“快照”自动重新启动所有进程。

快照不仅仅保存模型状态;它可以包括关于运行周期数、优化器状态或训练作业必要的其他属性的详细信息,以确保其连续性。

为什么使用``torchrun``

``torchrun``处理分布式训练的细节,让您不必费心。例如,

  • 您不需要设置环境变量或显式传递``rank``和``world_size``;``torchrun``会分配这些以及其他 环境变量

  • 无需在脚本中调用``mp.spawn``;您只需要一个通用的``main()``入口点,并使用``torchrun``启动脚本。这样,相同的脚本可以在非分布式、多节点和单节点设置中运行。

  • 优雅地从最后保存的训练快照重新启动训练。

优雅重启

为优雅重启,您应如以下方式结构化训练脚本:

def main():
  load_snapshot(snapshot_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_snapshot(snapshot_path)

如果出现故障,``torchrun``将终止所有进程并重新启动它们。每个进程入口点首先加载并初始化最后保存的快照,然后从那里继续训练。因此,在任何故障发生时,您只会失去最后保存快照以来的训练进度。

在弹性训练中,每当发生成员更改(添加或移除节点时),``torchrun``将终止并在可用设备上生成进程。具有这种结构可以确保您的训练作业能够继续进行,无需人工干预。

multigpu.pymultigpu_torchrun.py 的差异

进程组初始化

- def ddp_setup(rank, world_size):
+ def ddp_setup():
-     """
-     Args:
-         rank: Unique identifier of each process
-         world_size: Total number of processes
-     """
-     os.environ["MASTER_ADDR"] = "localhost"
-     os.environ["MASTER_PORT"] = "12355"
-     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+     init_process_group(backend="nccl")
     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

使用torchrun提供的环境变量

- self.gpu_id = gpu_id
+ self.gpu_id = int(os.environ["LOCAL_RANK"])

保存和加载快照

定期将所有相关信息存储在快照中,允许我们的训练作业在中断后无缝恢复。

+ def _save_snapshot(self, epoch):
+     snapshot = {}
+     snapshot["MODEL_STATE"] = self.model.module.state_dict()
+     snapshot["EPOCHS_RUN"] = epoch
+     torch.save(snapshot, "snapshot.pt")
+     print(f"Epoch {epoch} | Training snapshot saved at snapshot.pt")

+ def _load_snapshot(self, snapshot_path):
+     snapshot = torch.load(snapshot_path)
+     self.model.load_state_dict(snapshot["MODEL_STATE"])
+     self.epochs_run = snapshot["EPOCHS_RUN"]
+     print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

在Trainer构造函数中加载快照

当重新启动中断的训练作业时,您的脚本将首先尝试加载快照以继续训练。

class Trainer:
   def __init__(self, snapshot_path, ...):
   ...
+  if os.path.exists(snapshot_path):
+     self._load_snapshot(snapshot_path)
   ...

恢复训练

训练可以从最后运行的周期继续,而不是从零开始。

def train(self, max_epochs: int):
-  for epoch in range(max_epochs):
+  for epoch in range(self.epochs_run, max_epochs):
      self._run_epoch(epoch)

运行脚本

只需像非多进程脚本一样调用入口点函数即可;``torchrun``会自动生成进程。

if __name__ == "__main__":
   import sys
   total_epochs = int(sys.argv[1])
   save_every = int(sys.argv[2])
-  world_size = torch.cuda.device_count()
-  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
+  main(save_every, total_epochs)
- python multigpu.py 50 10
+ torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源