Shortcuts

CommDebugMode 入门

Created On: Aug 19, 2024 | Last Updated: Oct 08, 2024 | Last Verified: Nov 05, 2024

作者Anshul Sinha

在本教程中,我们将探讨如何使用PyTorch的DistributedTensor(DTensor)进行 CommDebugMode 的调试,并跟踪分布式训练环境中的集体操作。

前提条件

  • Python 3.8 - 3.11

  • PyTorch 2.2 或更高版本

什么是 CommDebugMode 以及它为何有用

随着模型规模不断扩大,用户寻求利用各种并行策略组合来扩展分布式训练。然而,现有解决方案之间缺乏互操作性构成了一个显著挑战,主要是因为缺乏能够桥接这些不同并行策略的统一抽象。为解决这一问题,PyTorch提出了 DistributedTensor(DTensor),它抽象了分布式训练中张量通信的复杂性,提供了无缝的用户体验。然而,在使用现有并行解决方案以及利用像DTensor这样的统一抽象开发并行解决方案时,缺乏关于底层集体通信的透明性可能会使高级用户难以识别和解决问题。为解决这一问题,CommDebugMode,一个Python上下文管理器,将成为DTensor的主要调试工具之一,使用户能够查看使用DTensor进行操作时的集体通信发生的时间和原因,从而有效解决这一问题。

使用 CommDebugMode

以下是使用 CommDebugMode 的方法:

# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
    noise_level=1, file_name="transformer_operation_log.txt"
)

# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)

以下是一个噪声级别为0时的MLPModule输出示例:

Expected Output:
    Global
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

要使用 CommDebugMode,您需要将运行模型的代码包装在 CommDebugMode 中,并调用您希望用来显示数据的API。您还可以使用噪声级别参数来控制显示信息的详细程度。以下是各个噪声级别的显示内容:

0. Prints module-level collective counts
1. Prints DTensor operations (not including trivial operations), module sharding information
2. Prints tensor operations (not including trivial operations)
3. Prints all operations

在上面的示例中,您可以看到在 MLPModule 的前向传递中发生了一次集体操作 all_reduce。此外,您可以使用 CommDebugMode 来定位all-reduce操作发生在 MLPModule 的第二个线性层。

以下是您可以使用JSON导出上传的交互式模块树可视化:

CommDebugMode Module Tree
Drag file here

结论

在这个教程中,我们学习了如何使用 CommDebugMode 来调试分布式张量和使用通信集成的并行解决方案。你可以在嵌入式可视化浏览器中使用自己的 JSON 输出。

有关 CommDebugMode 的详细信息,请参见 comm_mode_features_example.py

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源