(beta) Building a Simple CPU Performance Profiler with FX
Author: James Reed
In this tutorial, we are going to use FX to do the following:
- Capture PyTorch Python code in a way that we can inspect and gather statistics about the structure and execution of the code
- Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use torchvision's ResNet18 model for demonstration purposes.
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
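For reference, manual instrumentation of that sort might look like the following sketch. (This is illustrative only; stage_a and stage_b are hypothetical stand-ins for two consecutive pieces of a model.)

import time

def time_two_stages(stage_a, stage_b, x):
    # Hypothetical sketch of manual instrumentation: run two consecutive
    # stages of a model by hand and compare wall-clock timestamps.
    t0 = time.time()
    intermediate = stage_a(x)
    t1 = time.time()
    out = stage_b(intermediate)
    t2 = time.time()
    print(f'stage A: {t1 - t0:.6f}s, stage B: {t2 - t1:.6f}s')
    return out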
That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy over model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.
First, let's get some imports out of the way (we will be using all of these later in our code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
Note: tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.
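As a quick illustration of what tabulate does (a minimal sketch with made-up numbers, not data from this tutorial), it turns a list of rows plus headers into an aligned text table:

import tabulate

# Hypothetical example rows: (op name, runtime in seconds)
rows = [['conv1', 0.005], ['bn1', 0.001]]
print(tabulate.tabulate(rows, headers=['Op', 'Runtime (s)']))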
Capturing the Model with Symbolic Tracing
Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [num_users=1] = placeholder[target=x]
%conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation.
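To get a feel for this data structure, we can iterate over the nodes ourselves. The short sketch below (not part of the original tutorial) prints each node's opcode, target, and arguments; node.op, node.target, and node.args are part of FX's public Node API:

for node in traced_rn18.graph.nodes:
    # Each Node records what kind of call it is (``node.op``), what it
    # calls (``node.target``), and the values flowing into it (``node.args``).
    print(node.op, node.target, node.args)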
Creating a Profiling Interpreter
Next, we are going to create a class that inherits from torch.fx.Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that is run when you call a GraphModule, an alternative way to run a GraphModule is by executing each Node in the Graph one by one. That is the functionality that Interpreter provides: it interprets the graph node-by-node.
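Before subclassing it, we can sanity-check that an unmodified Interpreter produces the same result as calling the GraphModule directly (a quick check we add here, not part of the original tutorial; it relies only on Interpreter.run, which exists in torch.fx):

# The model is in eval mode, so both execution paths are deterministic
# and should produce identical outputs.
check_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    interpreted = Interpreter(traced_rn18).run(check_input)
    compiled = traced_rn18(check_input)
assert torch.allclose(interpreted, compiled)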
By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and its various parts took during those runs.
Let's define our ProfilingInterpreter class:
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
Note: we use Python's time.time function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance, and will only give us a first-order approximation. We use this simple technique only for the purpose of demonstration in this tutorial.
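If you want slightly more robust interval measurements, one possible variation (a sketch, not part of this tutorial's implementation) is to substitute time.perf_counter, which is monotonic and designed for timing short intervals:

import time

def timed_call(fn, *args):
    # Hypothetical helper: time a single call with a monotonic,
    # high-resolution clock instead of wall-clock time.
    t_start = time.perf_counter()
    result = fn(*args)
    t_end = time.perf_counter()
    return result, t_end - t_start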
Investigating the Performance of ResNet18
Now, we can use our ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model.
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module maxpool 0.00581646 8.75551
call_module conv1 0.00506926 7.63074
call_module layer4_0_conv2 0.00377035 5.6755
call_module layer1_0_conv1 0.00349903 5.26709
call_module layer1_1_conv1 0.00343132 5.16516
call_module layer4_1_conv2 0.0031662 4.76607
call_module layer4_1_conv1 0.00313568 4.72014
call_module layer2_1_conv1 0.00297475 4.47789
call_module layer2_1_conv2 0.00297022 4.47107
call_module layer3_1_conv2 0.0028255 4.25322
call_module layer1_0_conv2 0.00277519 4.17749
call_module layer1_1_conv2 0.00271678 4.08956
call_module layer3_1_conv1 0.0026722 4.02245
call_module layer3_0_conv2 0.0026114 3.93093
call_module layer2_0_conv2 0.00242352 3.64813
call_module layer4_0_conv1 0.00229311 3.45182
call_module layer2_0_conv1 0.00184298 2.77423
call_module bn1 0.00163937 2.46774
call_module layer3_0_conv1 0.00157094 2.36473
call_module layer2_0_downsample_0 0.00106168 1.59814
call_module layer4_0_downsample_0 0.000489473 0.736804
call_module layer3_0_downsample_0 0.000485659 0.731061
call_function add 0.000455618 0.685841
call_function add_1 0.000450611 0.678304
call_function add_3 0.000293255 0.441436
call_module relu 0.000248671 0.374323
call_module layer1_0_bn1 0.000205278 0.309005
call_module fc 0.000189781 0.285677
call_module layer1_1_bn1 0.000188351 0.283524
call_module layer1_0_bn2 0.000165939 0.249788
call_module avgpool 0.000165939 0.249788
call_module layer1_1_bn2 0.000163555 0.246199
call_module layer2_0_downsample_1 0.000151396 0.227896
call_module layer2_1_bn1 0.000146866 0.221077
call_module layer2_1_bn2 0.000146151 0.22
call_module layer3_0_bn2 0.000144243 0.217129
call_module layer2_0_bn1 0.000142574 0.214617
call_module layer3_1_bn1 0.000138283 0.208157
call_module layer3_1_bn2 0.000137806 0.207439
call_module layer2_0_bn2 0.000135183 0.203491
call_module layer4_1_bn1 0.000133753 0.201338
call_module layer4_0_bn1 0.000129223 0.194519
call_module layer4_0_bn2 0.000128984 0.19416
call_module layer4_1_bn2 0.000128269 0.193083
call_module layer3_0_bn1 0.000126362 0.190212
call_module layer3_0_downsample_1 0.000126362 0.190212
call_module layer4_0_downsample_1 0.000125885 0.189495
call_module layer1_0_relu 0.000117779 0.177292
call_module layer1_1_relu 0.000109196 0.164372
call_module layer1_0_relu_1 0.000107527 0.16186
call_module layer1_1_relu_1 0.000106573 0.160424
call_module layer4_1_relu 9.13143e-05 0.137455
call_module layer4_0_relu 9.08375e-05 0.136738
call_module layer2_1_relu 8.98838e-05 0.135302
call_module layer2_0_relu 8.96454e-05 0.134943
call_function add_2 8.86917e-05 0.133508
call_module layer2_0_relu_1 8.82149e-05 0.13279
call_module layer2_1_relu_1 8.70228e-05 0.130995
call_module layer3_1_relu 8.29697e-05 0.124894
call_module layer3_0_relu 8.13007e-05 0.122382
call_module layer4_1_relu_1 7.93934e-05 0.119511
call_module layer4_0_relu_1 7.9155e-05 0.119152
call_module layer3_1_relu_1 7.89165e-05 0.118793
call_function add_7 7.86781e-05 0.118434
call_function add_5 7.84397e-05 0.118075
call_module layer3_0_relu_1 7.79629e-05 0.117357
call_function add_6 7.34329e-05 0.110538
call_function add_4 7.05719e-05 0.106232
call_function flatten 5.17368e-05 0.0778794
placeholder x 3.00407e-05 0.0452203
output output 2.00272e-05 0.0301469
There are two things we should call out here:
- MaxPool2d takes up the most time. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
- BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize it in the Conv-BN Fusion with FX tutorial.
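Since ProfilingInterpreter accumulates a list of runtimes on every call, one simple way to smooth out timer noise (a suggestion beyond what this tutorial does) is to invoke the model several times before summarizing, so that statistics.mean averages over many runs:

interp = ProfilingInterpreter(rn18)
for _ in range(10):
    interp.run(input)
# Each node's "Average runtime" is now a mean over 10 invocations.
print(interp.summary(True))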
Conclusion
As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.
Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.