CommDebugMode
入门¶
Created On: Aug 19, 2024 | Last Updated: Oct 08, 2024 | Last Verified: Nov 05, 2024
作者:Anshul Sinha
在本教程中,我们将探讨如何使用PyTorch的DistributedTensor(DTensor)进行 CommDebugMode
的调试,并跟踪分布式训练环境中的集体操作。
前提条件¶
Python 3.8 - 3.11
PyTorch 2.2 或更高版本
什么是 CommDebugMode
以及它为何有用¶
随着模型规模不断扩大,用户寻求利用各种并行策略组合来扩展分布式训练。然而,现有解决方案之间缺乏互操作性构成了一个显著挑战,主要是因为缺乏能够桥接这些不同并行策略的统一抽象。为解决这一问题,PyTorch提出了 DistributedTensor(DTensor),它抽象了分布式训练中张量通信的复杂性,提供了无缝的用户体验。然而,在使用现有并行解决方案以及利用像DTensor这样的统一抽象开发并行解决方案时,缺乏关于底层集体通信的透明性可能会使高级用户难以识别和解决问题。为解决这一问题,CommDebugMode
,一个Python上下文管理器,将成为DTensor的主要调试工具之一,使用户能够查看使用DTensor进行操作时的集体通信发生的时间和原因,从而有效解决这一问题。
使用 CommDebugMode
¶
以下是使用 CommDebugMode
的方法:
# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
noise_level=1, file_name="transformer_operation_log.txt"
)
# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)
以下是一个噪声级别为0时的MLPModule输出示例:
Expected Output:
Global
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule.net1
MLPModule.relu
MLPModule.net2
FORWARD PASS
*c10d_functional.all_reduce: 1
要使用 CommDebugMode
,您需要将运行模型的代码包装在 CommDebugMode
中,并调用您希望用来显示数据的API。您还可以使用噪声级别参数来控制显示信息的详细程度。以下是各个噪声级别的显示内容:
在上面的示例中,您可以看到在 MLPModule
的前向传递中发生了一次集体操作 all_reduce。此外,您可以使用 CommDebugMode
来定位all-reduce操作发生在 MLPModule
的第二个线性层。
以下是您可以使用JSON导出上传的交互式模块树可视化:
结论¶
在这个教程中,我们学习了如何使用 CommDebugMode
来调试分布式张量和使用通信集成的并行解决方案。你可以在嵌入式可视化浏览器中使用自己的 JSON 输出。
有关 CommDebugMode
的详细信息,请参见 comm_mode_features_example.py