
Commit 5bda6b0

Chillee and brianjo authored
[FX] Added fuser tutorial (#1356)
* Added fuser tutorial
* updated index.rst
* fixed conclusion
* responded to some comments
* responded to comments
* respond

Co-authored-by: Brian Johnson <brianjo@fb.com>
1 parent 63bfc84 commit 5bda6b0

2 files changed: +269 -1 lines changed


index.rst

Lines changed: 7 additions & 1 deletion
@@ -217,6 +217,13 @@ Welcome to PyTorch Tutorials

.. Code Transformations with FX

+.. customcarditem::
+   :header: Building a Convolution/Batch Norm fuser in FX
+   :card_description: Build a simple FX pass that fuses batch norm into convolution to improve performance during inference.
+   :image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png
+   :link: intermediate/fx_conv_bn_fuser.html
+   :tags: FX
+
.. customcarditem::
   :header: Building a Simple Performance Profiler with FX
   :card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics
@@ -614,4 +621,3 @@ Additional Resources

   beginner/deeplabv3_on_ios
   beginner/deeplabv3_on_android
-
Lines changed: 262 additions & 0 deletions
@@ -0,0 +1,262 @@
# -*- coding: utf-8 -*-
"""
(beta) Building a Convolution/Batch Norm fuser in FX
*******************************************************
**Author**: `Horace He <https://github.com/chillee>`_

In this tutorial, we are going to use FX, a toolkit for composable function
transformations of PyTorch, to do the following:

1) Find patterns of conv/batch norm in the data dependencies.
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.

Note that this optimization only works for models in inference mode (i.e. `model.eval()`).

We will be building the fuser that exists here:
https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py

"""

######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn

######################################################################
# For this tutorial, we are going to create a model consisting of convolutions
# and batch norms. Note that this model has some tricky components - some of
# the conv/batch norm patterns are hidden within Sequentials and one of the
# BatchNorms is wrapped in another Module.

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)
    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M()

model.eval()

######################################################################
# Fusing Convolution with Batch Norm
# -----------------------------------------
# One of the primary challenges with trying to automatically fuse convolution
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
# accessing the computational graph. FX resolves this problem by symbolically
# tracing the actual operations called, so that we can track the computations
# through the `forward` call, nested within Sequential modules, or wrapped in
# a user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)

######################################################################
# This gives us a graph representation of our model. Note that both the modules
# hidden within the sequential as well as the wrapped Module have been inlined
# into the graph. This is the default level of abstraction, but it can be
# configured by the pass writer. More information can be found at the FX
# overview https://pytorch.org/docs/master/fx.html#module-torch.fx

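######################################################################
# The sketch below is not part of the original fuser; it is a minimal,
# hedged illustration of how that level of abstraction could be configured.
# It assumes we want `WrappedBatchNorm` to stay opaque (a single
# `call_module` node) instead of being inlined; the tracer name
# `KeepWrappedTracer` is purely illustrative.

class KeepWrappedTracer(fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, qualname: str) -> bool:
        # Treat our wrapper as a leaf so FX emits one `call_module` node
        # for it rather than tracing into its `forward`.
        if isinstance(m, WrappedBatchNorm):
            return True
        return super().is_leaf_module(m, qualname)

# With this tracer, the wrapped batch norm is no longer visible to a fusion
# pass, which is usually not what we want here.
print(KeepWrappedTracer().trace(model))
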
####################################
# Fusing Convolution with Batch Norm
# ----------------------------------
# Unlike some other fusions, fusion of convolution with batch norm does not
# require any new operators. Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
# into the preceding convolution's weights. This allows us to remove the batch
# norm entirely from our model! Read
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
# code here is copied from
# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
# for clarity purposes.
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)

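######################################################################
# As a quick, hedged sanity check (not part of the original fuser), we can
# verify the folding math above on a standalone conv/batch norm pair. The
# modules below are throwaway examples created only for this check.

_conv = nn.Conv2d(3, 3, 3).eval()
_bn = nn.BatchNorm2d(3).eval()
_fused = fuse_conv_bn_eval(_conv, _bn)
_x = torch.randn(1, 3, 8, 8)
# In eval mode, the fused conv should match conv followed by batch norm.
torch.testing.assert_allclose(_fused(_x), _bn(_conv(_x)))
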
####################################
# FX Fusion Pass
# ----------------------------------
# Now that we have our computational graph as well as a method for fusing
# convolution and batch norm, all that remains is to iterate over the FX graph
# and apply the desired fusions.


def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)

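######################################################################
# As a brief aside (not part of the original pass), here is how `_parent_name`
# splits the qualified names that `named_modules()` produces for our toy
# model; children of the root module get an empty parent path.

print(_parent_name('nested.0'))   # ('nested', '0')
print(_parent_name('conv1'))      # ('', 'conv1')
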
def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    # The first step of most FX passes is to symbolically trace our model to
    # obtain a `GraphModule`. This is a representation of our original model
    # that is functionally identical to our original model, except that we now
    # also have a graph representation of our forward pass.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # The primary representations for working with FX are the `Graph` and the
    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
    # `Graph` is also what generates `GraphModule.code`.
    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
    # iterate through all of the operations in our graph, we iterate over each
    # `Node` in our `Graph`.
    for node in fx_model.graph.nodes:
        # The FX IR contains several types of nodes, which generally represent
        # call sites to modules, functions, or methods. The type of node is
        # determined by `Node.op`.
        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called. Here, we check `Node.target` to see if it's a
        # batch norm module, and then check `Node.args[0].target` to see if the
        # input `Node` is a convolution.
        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(node.args[0])
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model


######################################################################
# .. note::
#       We make some simplifications here for demonstration purposes, such as only
#       matching 2D convolutions. View
#       https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
#       for a more usable pass.

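######################################################################
# As a hedged sketch of one such generalization (not the version used in this
# tutorial or in the linked fuser), the `call_module` type check inside `fuse`
# could match a tuple of conv/batch norm types instead of only the 2D
# variants. The names `CONV_TYPES`, `BN_TYPES`, and `is_conv_bn_pair` are
# illustrative only, and this helper is not wired into `fuse` here; a fuller
# pass would also check that the conv and batch norm dimensionalities match.

CONV_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def is_conv_bn_pair(node: fx.Node, modules: Dict[str, Any]) -> bool:
    # A node qualifies if it calls a batch norm module and its first input is
    # the output of a conv module call.
    if node.op != 'call_module' or type(modules.get(node.target)) not in BN_TYPES:
        return False
    prev = node.args[0] if node.args else None
    return (isinstance(prev, fx.Node) and prev.op == 'call_module'
            and type(modules.get(prev.target)) in CONV_TYPES)
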
######################################################################
# Testing out our Fusion Pass
# -----------------------------------------
# We can now run this fusion pass on our initial toy model and verify that our
# results are identical. In addition, we can print out the code for our fused
# model and verify that there are no more batch norms.


fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_allclose(fused_model(inp), model(inp))


######################################################################
# Benchmarking our Fusion on ResNet18
# -----------------------------------
# We can test our fusion pass on a larger model like ResNet18 and see how much
# this pass improves inference performance.
import torchvision.models as models
import time

rn18 = models.resnet18()
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)

def benchmark(model, iters=20):
    for _ in range(10):
        model(inp)
    begin = time.time()
    for _ in range(iters):
        model(inp)
    return str(time.time()-begin)

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
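######################################################################
# One optional refinement (an assumption on our part, not something this
# tutorial requires): since we only care about inference, the same timings
# can be collected with gradient tracking disabled via `torch.no_grad()`.

with torch.no_grad():
    print("Unfused time (no_grad): ", benchmark(rn18))
    print("Fused time (no_grad): ", benchmark(fused_rn18))
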
######################################################################
# As we previously saw, the output of our FX transformation is
# (TorchScript-able) PyTorch code, so we can easily `jit.script` the output to
# try to increase our performance even more. In this way, our FX model
# transformation composes with TorchScript with no issues.
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))


############
# Conclusion
# ----------
# As we can see, using FX we can easily write static graph transformations on
# PyTorch code.
#
# Since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.
