(Beta) Building a Simple CPU Performance Profiler with FX
Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Not Verified
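Author: James Reed
In this tutorial, we are going to use FX to do the following:
1. Capture PyTorch Python code in a way that lets us inspect it and gather statistics about its structure and execution.
2. Build a small class that serves as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.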
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
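Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model take the longest?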
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
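A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
That technique is certainly applicable to PyTorch code, but it would be nicer if we did not have to copy the model code and edit it, especially for code we did not write (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without modifying any source.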
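For contrast, here is a minimal sketch of what the manual approach looks like. ``TinyNet`` and its ``timings`` attribute are hypothetical stand-ins invented for this illustration; they are not part of torchvision or of this tutorial's profiler. Note that the timing code has to be written directly into ``forward``, which is exactly the kind of source edit we would like to avoid.

```python
import time

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical two-op model with hand-written timing code in forward()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        # Filled in by the manual instrumentation below.
        self.timings = {}

    def forward(self, x):
        t0 = time.time()
        x = self.conv(x)
        t1 = time.time()
        x = self.relu(x)
        t2 = time.time()
        # Differences between neighbouring timestamps give per-region times.
        self.timings = {"conv": t1 - t0, "relu": t2 - t1}
        return x


net = TinyNet().eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 16, 16))

for name, seconds in net.timings.items():
    print(f"{name}: {seconds:.6f} s")
```

Every new region of interest means another pair of timestamps and another edit to the model source, which does not scale to a model the size of ResNet18.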
First, let's get some imports out of the way (we will be using all of these later in the code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
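Note
``tabulate`` is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you have installed it from your favorite Python package source.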
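Capturing the Model with Symbolic Tracing
Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.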
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [num_users=1] = placeholder[target=x]
%conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
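This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as ``args`` and ``kwargs`` on each node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.org/docs/master/fx.html.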
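To make the node and edge structure concrete, the following sketch traces a much smaller module and walks its graph. ``AddRelu`` is a hypothetical toy module chosen only to keep the output short; the attributes read from each node (``op``, ``name``, ``target``, ``args``) are the same ones that appear in the ResNet18 printout above.

```python
import torch
import torch.fx


class AddRelu(torch.nn.Module):
    """Hypothetical toy module: a linear layer with a residual add and a ReLU."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x) + x)


traced = torch.fx.symbolic_trace(AddRelu())

ops = []
for node in traced.graph.nodes:
    ops.append(node.op)
    # ``args`` holds references to other nodes: these are the graph's edges.
    print(f"{node.op:<14} name={node.name:<8} target={node.target} args={node.args}")
```

Each line of output corresponds to one call-site: the input placeholder, the ``linear`` submodule call, the two function calls (the add and the ReLU), and the output.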
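Creating a Profiling Interpreter
Next, we are going to create a class that inherits from ``torch.fx.Interpreter``. Although the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code that is run when you call the ``GraphModule``, an alternative way to run a ``GraphModule`` is to execute each ``Node`` in the graph one by one. That is the functionality ``Interpreter`` provides: it interprets the graph node by node.
By inheriting from ``Interpreter``, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each part of the model took during those runs.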
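As a warm-up, here is a minimal sketch of the override mechanism on its own. ``CountingInterpreter`` is a hypothetical subclass written for this illustration: it only counts how many nodes are executed and otherwise defers to ``Interpreter``. The small ``Sequential`` model is likewise an assumption made to keep the example fast.

```python
import torch
import torch.fx
from torch.fx import Interpreter


class CountingInterpreter(Interpreter):
    """Hypothetical subclass that counts the nodes it executes."""

    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.nodes_run = 0

    def run_node(self, n: torch.fx.Node):
        # Called once per node; add our behavior, then delegate.
        self.nodes_run += 1
        return super().run_node(n)


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
gm = torch.fx.symbolic_trace(model)
x = torch.randn(2, 4)

interp = CountingInterpreter(gm)
out_interp = interp.run(x)   # node-by-node interpretation
out_direct = gm(x)           # compiled Python code of the GraphModule

same = torch.allclose(out_interp, out_direct)
print(f"nodes executed: {interp.nodes_run}, outputs match: {same}")
```

Both execution paths compute the same result; the interpreter simply gives us a hook around every node, which is where the timing code will go.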
Let's define our ``ProfilingInterpreter`` class:
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)
        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}
    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.
    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val
    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.
    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val
    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.
    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)
        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])
        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)
        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
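Note
We use Python's ``time.time`` function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance and will only give us a first-order approximation. We use this simple technique only for the purpose of demonstration in this tutorial.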
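One possible refinement, sketched here as an assumption rather than as part of the tutorial's profiler, is to time with ``time.perf_counter`` from the standard library, which is a monotonic, high-resolution clock intended for measuring short durations, and to average over several repeats. The helper name ``time_call`` and its ``repeats`` parameter are hypothetical.

```python
import statistics
import time


def time_call(fn, *args, repeats=5):
    """Return the mean elapsed seconds of ``fn(*args)`` over ``repeats`` calls."""
    samples = []
    for _ in range(repeats):
        t_start = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - t_start)
    return statistics.mean(samples)


# Time a cheap stand-in workload; any callable could be used instead.
mean_sec = time_call(sum, range(100_000))
print(f"mean runtime: {mean_sec:.6f} s")
```

Even with a better clock, single-run numbers remain noisy, so the results should still be read as rough estimates.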
Investigating the Performance of ResNet18
We can now use ``ProfilingInterpreter`` to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module conv1 0.00257397 9.01507
call_module layer1_0_conv1 0.00206637 7.23728
call_module layer1_0_conv2 0.00189829 6.64857
call_module layer1_1_conv2 0.00184464 6.46069
call_module maxpool 0.00156546 5.48286
call_module bn1 0.00152111 5.32754
call_module layer2_1_conv2 0.00139427 4.8833
call_module layer2_0_downsample_0 0.00121236 4.24617
call_module layer1_1_conv1 0.000962734 3.37188
call_module layer4_0_conv2 0.000887871 3.10968
call_module layer4_1_conv2 0.000799656 2.80072
call_module layer2_1_conv1 0.000715494 2.50595
call_module layer3_1_conv1 0.000664473 2.32725
call_module layer4_0_conv1 0.000620365 2.17277
call_module layer4_1_conv1 0.000568867 1.9924
call_module layer3_1_conv2 0.000558853 1.95733
call_module layer3_0_downsample_0 0.000512838 1.79617
call_module layer2_0_conv2 0.000511408 1.79116
call_module layer3_0_conv2 0.000474453 1.66173
call_module layer3_0_conv1 0.000458479 1.60578
call_function add 0.000437498 1.5323
call_module layer2_0_conv1 0.000427961 1.49889
call_function add_1 0.000420332 1.47217
call_module layer4_0_downsample_0 0.00035429 1.24087
call_function add_2 0.000230789 0.808317
call_module relu 0.000218153 0.76406
call_function add_3 0.000208378 0.729823
call_module avgpool 0.000161171 0.564486
call_module layer1_0_bn1 0.000131369 0.460106
call_module layer2_1_bn2 0.000116348 0.407499
call_module fc 0.000115395 0.404158
call_module layer1_0_bn2 0.000101805 0.356561
call_module layer1_1_bn2 9.94205e-05 0.348211
call_module layer4_0_relu 9.77516e-05 0.342366
call_module layer1_1_bn1 9.53674e-05 0.334015
call_module layer1_0_relu 8.98838e-05 0.314809
call_module layer4_1_bn2 8.84533e-05 0.309799
call_module layer2_1_relu 8.7738e-05 0.307294
call_function add_5 8.70228e-05 0.304789
call_function add_4 8.36849e-05 0.293098
call_module layer2_0_bn1 8.29697e-05 0.290593
call_module layer2_0_downsample_1 8.10623e-05 0.283913
call_module layer1_1_relu 8.05855e-05 0.282243
call_module layer3_0_bn1 7.98702e-05 0.279738
call_module layer4_1_bn1 7.89165e-05 0.276398
call_module layer2_0_relu 7.77245e-05 0.272222
call_function add_7 7.7486e-05 0.271387
call_module layer4_0_bn1 7.65324e-05 0.268047
call_module layer4_0_bn2 7.62939e-05 0.267212
call_module layer4_0_downsample_1 7.62939e-05 0.267212
call_module layer3_0_relu 7.60555e-05 0.266377
call_module layer2_0_bn2 7.53403e-05 0.263872
call_module layer3_1_bn2 7.53403e-05 0.263872
call_module layer2_1_bn1 7.43866e-05 0.260532
call_module layer4_1_relu 7.41482e-05 0.259697
call_module layer3_0_downsample_1 7.36713e-05 0.258027
call_module layer3_0_bn2 7.15256e-05 0.250511
call_module layer3_1_bn1 7.05719e-05 0.247171
call_module layer3_1_relu 6.8903e-05 0.241326
call_function add_6 6.77109e-05 0.237151
call_module layer2_0_relu_1 6.10352e-05 0.21377
call_module layer4_0_relu_1 6.00815e-05 0.21043
call_module layer1_0_relu_1 5.88894e-05 0.206254
call_module layer4_1_relu_1 5.81741e-05 0.203749
call_module layer1_1_relu_1 5.76973e-05 0.202079
call_module layer2_1_relu_1 5.26905e-05 0.184543
call_module layer3_1_relu_1 5.26905e-05 0.184543
call_module layer3_0_relu_1 5.05447e-05 0.177028
placeholder x 2.95639e-05 0.103545
call_function flatten 2.5034e-05 0.087679
output output 7.62939e-06 0.0267212
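There are two things we should call out here:
In the run shown above the convolutions dominate, but ``MaxPool2d`` (``maxpool``, about 5.5% of the total runtime) is still among the most expensive ops, which is more than one would expect of a pooling layer. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
``BatchNorm2d`` also takes up a noticeable amount of time: the first ``bn1`` alone accounts for about 5.3% of the total. We can continue this line of thinking and optimize it with Conv-BN fusion, as described in the FX tutorial: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html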
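To give a feel for why Conv-BN fusion is a safe optimization at inference time, the sketch below uses the existing ``torch.nn.utils.fusion.fuse_conv_bn_eval`` helper on a single hypothetical Conv2d/BatchNorm2d pair and checks that the fused convolution matches the original two-op sequence. This is only an illustration with made-up layer sizes and randomized statistics; the linked tutorial shows how to apply fusion across a whole traced model.

```python
import torch
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)

# Both modules must be in eval mode for fusion to be valid.
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
bn = torch.nn.BatchNorm2d(8).eval()

# Give the batch norm non-trivial statistics so the check is meaningful.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Fold the batch norm's affine transform into the convolution's weights.
fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    max_diff = (bn(conv(x)) - fused(x)).abs().max().item()

print(f"max abs difference: {max_diff:.2e}")
```

After fusion, one module call replaces two, so the separate ``BatchNorm2d`` rows would no longer appear in a profile of the fused model.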
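Conclusion
As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.
Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.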