# -*- coding: utf-8 -*-
"""
(beta) Building a Convolution/Batch Norm fuser in FX
*******************************************************
**Author**: `Horace He <https://github.com/chillee>`_

In this tutorial, we are going to use FX, a toolkit for composable function
transformations of PyTorch, to do the following:

1) Find patterns of conv/batch norm in the data dependencies.
2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.

Note that this optimization only works for models in inference mode (i.e. `model.eval()`).

We will be building the fuser that exists here:
https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py

"""


######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn

######################################################################
# For this tutorial, we are going to create a model consisting of convolutions
# and batch norms. Note that this model has some tricky components - some of
# the conv/batch norm patterns are hidden within Sequentials and one of the
# BatchNorms is wrapped in another Module.

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)
    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M()

model.eval()

######################################################################
# Fusing Convolution with Batch Norm
# -----------------------------------------
# One of the primary challenges with trying to automatically fuse convolution
# and batch norm in PyTorch is that PyTorch does not provide an easy way of
# accessing the computational graph. FX resolves this problem by symbolically
# tracing the actual operations called, so that we can track the computations
# through the `forward` call, nested within Sequential modules, or wrapped in
# a user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)

######################################################################
# This gives us a graph representation of our model. Note that both the modules
# hidden within the sequential as well as the wrapped Module have been inlined
# into the graph. This is the default level of abstraction, but it can be
# configured by the pass writer. More information can be found at the FX
# overview https://pytorch.org/docs/master/fx.html#module-torch.fx
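
######################################################################
# If you would like to inspect the traced model in other ways, two options
# (shown here as an optional sketch, not part of the original fuser code) are
# to print the Python code that FX regenerates from the graph, or to print the
# nodes in tabular form. `Graph.print_tabular()` requires the `tabulate`
# package, so it is left commented out here.

print(traced_model.code)
# traced_model.graph.print_tabular()  # optional: requires `tabulate`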


####################################
# Fusing Convolution with Batch Norm
# ----------------------------------
# Unlike some other fusions, fusion of convolution with batch norm does not
# require any new operators. Instead, as batch norm during inference
# consists of a pointwise add and multiply, these operations can be "baked"
# into the preceding convolution's weights. This allows us to remove the batch
# norm entirely from our model! Read
# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
# code here is copied from
# https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
# for clarity purposes.
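#
# Concretely (a short restatement of what the helpers below implement): if the
# convolution computes `y = W * x + b` and the batch norm in eval mode computes
# `(y - mean) / sqrt(var + eps) * gamma + beta`, then the fused convolution
# uses weights `W' = W * gamma / sqrt(var + eps)` and bias
# `b' = (b - mean) * gamma / sqrt(var + eps) + beta`.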
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)

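######################################################################
# As a quick sanity check (not part of the original fuser code), we can fuse a
# single conv/batch norm pair by hand and confirm that the fused module matches
# the unfused pair in eval mode. The names `_conv`, `_bn`, and `_x` are local
# to this illustration.

_conv = nn.Conv2d(1, 1, 1).eval()
_bn = nn.BatchNorm2d(1).eval()
_x = torch.randn(2, 1, 4, 4)
torch.testing.assert_allclose(fuse_conv_bn_eval(_conv, _bn)(_x), _bn(_conv(_x)))
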

####################################
# FX Fusion Pass
# ----------------------------------
# Now that we have our computational graph as well as a method for fusing
# convolution and batch norm, all that remains is to iterate over the FX graph
# and apply the desired fusions.


def _parent_name(target: str) -> Tuple[str, str]:
    """
    Splits a qualname into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)

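######################################################################
# For example, `_parent_name("nested.0")` returns `("nested", "0")`, so
# `replace_node_module` can `setattr` the new module onto the correct parent.
# This one-liner (not in the original fuser code) just demonstrates the helper.

print(_parent_name("nested.0"))  # ('nested', '0')
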

def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    # The first step of most FX passes is to symbolically trace our model to
    # obtain a `GraphModule`. This is a representation of our original model
    # that is functionally identical to our original model, except that we now
    # also have a graph representation of our forward pass.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # The primary representations for working with FX are the `Graph` and the
    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
    # `Graph` is also what generates `GraphModule.code`.
    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
    # iterate through all of the operations in our graph, we iterate over each
    # `Node` in our `Graph`.
    for node in fx_model.graph.nodes:
        # The FX IR contains several types of nodes, which generally represent
        # call sites to modules, functions, or methods. The type of node is
        # determined by `Node.op`.
        if node.op != 'call_module':  # If our current node isn't calling a Module then we can ignore it.
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called. Here, we check `Node.target` to see if it's a
        # batch norm module, and then check `Node.args[0].target` to see if the
        # input `Node` is a convolution.
        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(node.args[0])
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model


######################################################################
# .. note::
#      We make some simplifications here for demonstration purposes, such as only
#      matching 2D convolutions. View
#      https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
#      for a more usable pass.

######################################################################
# Testing out our Fusion Pass
# -----------------------------------------
# We can now run this fusion pass on our initial toy model and verify that our
# results are identical. In addition, we can print out the code for our fused
# model and verify that there are no more batch norms.


fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_allclose(fused_model(inp), model(inp))


######################################################################
# Benchmarking our Fusion on ResNet18
# -----------------------------------
# We can test our fusion pass on a larger model like ResNet18 and see how much
# this pass improves inference performance.
import torchvision.models as models
import time

rn18 = models.resnet18()
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)

def benchmark(model, iters=20):
    for _ in range(10):
        model(inp)
    begin = time.time()
    for _ in range(iters):
        model(inp)
    return str(time.time()-begin)

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))
######################################################################
# As we previously saw, the output of our FX transformation is
# (TorchScript-able) PyTorch code, so we can easily `jit.script` the output to
# try and increase our performance even more. In this way, our FX model
# transformation composes with TorchScript with no issues.
jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))


######################################################################
# Conclusion
# ----------
# As we can see, using FX we can easily write static graph transformations on
# PyTorch code.
#
# Since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.