Commit 6ce4bb0

[quant] Guide on using FX Graph Mode Quantization and symbolic trace (#1312)
* [quant] Guide on using FX Graph Mode Quantization and symbolic trace
* Update user guide title
1 parent c5794c9 commit 6ce4bb0

File tree

2 files changed: +348 −4 lines changed

prototype_source/README.txt

Lines changed: 9 additions & 4 deletions
@@ -24,10 +24,15 @@ Prototype Tutorials
    Vulkan Backend User Workflow
    https://pytorch.org/tutorials/intermediate/vulkan_workflow.html

-7. fx_graph_mode_static_quantization.py
+7. fx_graph_mode_ptq_static.py
    FX Graph Mode Post Training Static Quantization
-   https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static_tutorial.html
+   https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html

-8. fx_graph_mode_dynamic_quantization.py
+8. fx_graph_mode_ptq_dynamic.py
    FX Graph Mode Post Training Dynamic Quantization
-   https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic_tutorial.html
+   https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_dynamic.html
+
+9. fx_graph_mode_quant_guide.py
+   FX Graph Mode Quantization User Guide
+   https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html
Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""
(prototype) FX Graph Mode Quantization User Guide
===========================================================

**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_

FX Graph Mode Quantization requires a symbolically traceable model.
We use the FX framework (TODO: link) to convert a symbolically traceable nn.Module instance to IR,
and we operate on the IR to execute the quantization passes.
Please post questions about symbolically tracing your model in the `PyTorch Discussion Forum <https://discuss.pytorch.org/c/quantization/17>`_.

Quantization will only work on the symbolically traceable parts of your model.
Data dependent control flow (if statements / for loops etc. using symbolically traced values) is one common pattern which is not supported.
If your model is not symbolically traceable end to end, you have a couple of options to enable FX Graph Mode Quantization on only a part of the model.
You can use any combination of these options:

1. Non traceable code doesn't need to be quantized
   a. Symbolically trace only the code that needs to be quantized
   b. Skip symbolically tracing the non-traceable code

2. Non traceable code needs to be quantized
   a. Refactor your code to make it symbolically traceable
   b. Write your own observed and quantized submodule

"""
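Before picking an option, it helps to check which parts of a model actually trace. A minimal sketch using the stable ``torch.fx.symbolic_trace`` API (the two module bodies are illustrative stand-ins, not from this tutorial):

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class Traceable(nn.Module):
    # pure tensor ops: symbolically traceable
    def forward(self, x):
        return torch.relu(x) + 1

class NotTraceable(nn.Module):
    # the branch condition depends on a traced value:
    # data dependent control flow, not traceable
    def forward(self, x):
        if x.sum() > 0:
            return x
        return -x

gm = symbolic_trace(Traceable())   # succeeds, returns a GraphModule

traced_ok = True
try:
    symbolic_trace(NotTraceable())  # raises during tracing
except Exception as e:
    traced_ok = False
    print("not traceable:", type(e).__name__)
```

Running the failing case prints the tracer's error type, which is how non-traceable regions usually surface in practice.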
####################################################################
# If the code that is not symbolically traceable does not need to be quantized,
# we have the following two options to run FX Graph Mode Quantization:
#
# 1.a. Symbolically trace only the code that needs to be quantized
# -----------------------------------------------------------------
#
# When the whole model is not symbolically traceable but the submodule we want to quantize is
# symbolically traceable, we can run quantization only on that submodule.
#
#
# before:
#
# .. code:: python
#
#     class M(nn.Module):
#
#         def forward(self, x):
#             x = non_traceable_code_1(x)
#             x = traceable_code(x)
#             x = non_traceable_code_2(x)
#             return x
#
#
# after:
#
# .. code:: python
#
#     class FP32Traceable(nn.Module):
#
#         def forward(self, x):
#             x = traceable_code(x)
#             return x
#
#     class M(nn.Module):
#
#         def __init__(self):
#             super().__init__()
#             self.traceable_submodule = FP32Traceable(...)
#
#         def forward(self, x):
#             x = non_traceable_code_1(x)
#             # We'll only symbolically trace/quantize this submodule
#             x = self.traceable_submodule(x)
#             x = non_traceable_code_2(x)
#             return x
#
#
# quantization code:
#
# .. code:: python
#
#     qconfig_dict = {"": qconfig}
#     model_fp32.traceable_submodule = \
#         prepare_fx(model_fp32.traceable_submodule, qconfig_dict)
#
# Note that if the original model needs to be preserved, you will have to
# copy it yourself before calling the quantization APIs.
#
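The swap-the-submodule pattern above can be sketched end to end. Here ``symbolic_trace`` stands in for the ``prepare_fx`` call (whose qconfig details are out of scope), and the data dependent branch is an illustrative stand-in for ``non_traceable_code_1``:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class FP32Traceable(nn.Module):
    def forward(self, x):
        return torch.relu(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.traceable_submodule = FP32Traceable()

    def forward(self, x):
        # stand-in for non_traceable_code_1: data dependent branch
        if x.sum() > 0:
            x = x * 2
        x = self.traceable_submodule(x)
        return x

model_fp32 = M()
# process only the traceable submodule and swap it back in;
# with quantization this call would be prepare_fx(..., qconfig_dict) instead
model_fp32.traceable_submodule = symbolic_trace(model_fp32.traceable_submodule)
out = model_fp32(torch.tensor([1.0, -3.0]))
```

The enclosing model keeps its Python control flow; only the attribute that was traced is replaced.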
#####################################################
# 1.b. Skip symbolically tracing the non-traceable code
# -----------------------------------------------------
# When we have some non-traceable code in the module, and this part of the code doesn't need to be quantized,
# we can factor out this part of the code into a submodule and skip symbolically tracing that submodule.
#
#
# before:
#
# .. code:: python
#
#     class M(nn.Module):
#
#         def forward(self, x):
#             x = self.traceable_code_1(x)
#             x = non_traceable_code(x)
#             x = self.traceable_code_2(x)
#             return x
#
#
# after, non-traceable parts moved to a module and marked as a leaf
#
# .. code:: python
#
#     class FP32NonTraceable(nn.Module):
#
#         def forward(self, x):
#             x = non_traceable_code(x)
#             return x
#
#     class M(nn.Module):
#
#         def __init__(self):
#             ...
#             self.non_traceable_submodule = FP32NonTraceable(...)
#
#         def forward(self, x):
#             x = self.traceable_code_1(x)
#             # we will configure the quantization call to not trace through
#             # this submodule
#             x = self.non_traceable_submodule(x)
#             x = self.traceable_code_2(x)
#             return x
#
# quantization code:
#
# .. code:: python
#
#     qconfig_dict = {"": qconfig}
#
#     prepare_custom_config_dict = {
#         # option 1: skip the submodule by its name
#         "non_traceable_module_name": "non_traceable_submodule",
#         # option 2: skip all instances of its class
#         "non_traceable_module_class": [FP32NonTraceable],
#     }
#     model_prepared = prepare_fx(
#         model_fp32,
#         qconfig_dict,
#         prepare_custom_config_dict=prepare_custom_config_dict,
#     )
#
# If the code that is not symbolically traceable needs to be quantized, we have the following two options:
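The "mark as a leaf" behavior that ``non_traceable_module_name`` / ``non_traceable_module_class`` request can be sketched directly with a custom ``torch.fx.Tracer`` (roughly the mechanism involved); ``SkipTracer`` and the module bodies below are illustrative:

```python
import torch
import torch.nn as nn
from torch.fx import Tracer, GraphModule

class FP32NonTraceable(nn.Module):
    # data dependent control flow; cannot be symbolically traced
    def forward(self, x):
        return x if x.sum() > 0 else -x

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.non_traceable_submodule = FP32NonTraceable()

    def forward(self, x):
        x = x + 1
        x = self.non_traceable_submodule(x)
        return x * 2

class SkipTracer(Tracer):
    # treat the non-traceable submodule as an opaque call_module node
    def is_leaf_module(self, m, module_qualified_name):
        return isinstance(m, FP32NonTraceable) or \
            super().is_leaf_module(m, module_qualified_name)

m = M()
gm = GraphModule(m, SkipTracer().trace(m))
out = gm(torch.tensor([-2.0, -1.0]))
```

The traced graph records a single ``call_module`` node for the submodule, so its Python control flow still runs eagerly at call time.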
##########################################################
# 2.a Refactor your code to make it symbolically traceable
# --------------------------------------------------------
# If it is easy to refactor the code to make it symbolically traceable,
# we can refactor the code and remove the use of non-traceable constructs in Python.
#
# More information about symbolic tracing support can be found in: (TODO: link)
#
# before:
#
# .. code:: python
#
#     def transpose_for_scores(self, x):
#         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#         x = x.view(*new_x_shape)
#         return x.permute(0, 2, 1, 3)
#
#
# This is not symbolically traceable because in x.view(*new_x_shape)
# unpacking is not supported; however, it is easy to remove the unpacking
# since x.view also supports list input.
#
#
# after:
#
# .. code:: python
#
#     def transpose_for_scores(self, x):
#         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#         x = x.view(new_x_shape)
#         return x.permute(0, 2, 1, 3)
#
#
# quantization code:
#
# This can be combined with other approaches and the quantization code
# depends on the model.
#
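The refactor can be verified with symbolic tracing. This sketch uses free functions and fixed head sizes (2 and 3) in place of the attention attributes; only the ``view(*shape)`` vs ``view(shape)`` difference matters:

```python
import torch
from torch.fx import symbolic_trace

def transpose_unpack(x):
    new_x_shape = x.size()[:-1] + (2, 3)
    return x.view(*new_x_shape).permute(0, 2, 1, 3)   # unpacks a traced value

def transpose_list(x):
    new_x_shape = x.size()[:-1] + (2, 3)
    return x.view(new_x_shape).permute(0, 2, 1, 3)    # passes it whole

# unpacking iterates the traced shape, which the tracer rejects
unpack_traceable = True
try:
    symbolic_trace(transpose_unpack)
except Exception:
    unpack_traceable = False

# the list/tuple form traces fine and runs
gm = symbolic_trace(transpose_list)
out = gm(torch.randn(4, 5, 6))    # (4, 5) + (2, 3) -> view -> permute
```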
#######################################################
# 2.b. Write your own observed and quantized submodule
# -----------------------------------------------------
#
# If the non-traceable code can't be refactored to be symbolically traceable,
# for example it has some loops that can't be eliminated, like nn.LSTM,
# we'll need to factor out the non-traceable code into a submodule (we call it CustomModule in fx graph mode quantization) and
# define the observed and quantized version of the submodule (in post training static quantization or quantization aware training for static quantization)
# or define the quantized version (in post training dynamic and weight only quantization).
#
#
# before:
#
# .. code:: python
#
#     class M(nn.Module):
#
#         def forward(self, x):
#             x = traceable_code_1(x)
#             x = non_traceable_code(x)
#             x = traceable_code_2(x)
#             return x
#
# after:
#
# 1. Factor out non_traceable_code to FP32NonTraceable,
# the non-traceable logic wrapped in a module
#
# .. code:: python
#
#     class FP32NonTraceable:
#         ...
#
#
# 2. Define observed version of FP32NonTraceable
#
# .. code:: python
#
#     class ObservedNonTraceable:
#
#         @classmethod
#         def from_float(cls, ...):
#             ...
#
# 3. Define statically quantized version of FP32NonTraceable
# and a class method "from_observed" to convert from ObservedNonTraceable
# to StaticQuantNonTraceable
#
# .. code:: python
#
#     class StaticQuantNonTraceable:
#
#         @classmethod
#         def from_observed(cls, ...):
#             ...
#
#
# .. code:: python
#
#     # refactor parent class to call FP32NonTraceable
#     class M(nn.Module):
#
#         def __init__(self):
#             ...
#             self.non_traceable_submodule = FP32NonTraceable(...)
#
#         def forward(self, x):
#             x = self.traceable_code_1(x)
#             # this part will be quantized manually
#             x = self.non_traceable_submodule(x)
#             x = self.traceable_code_2(x)
#             return x
#
#
# quantization code:
#
#
# .. code:: python
#
#     # post training static quantization or
#     # quantization aware training (that produces a statically quantized module)
#     prepare_custom_config_dict = {
#         "float_to_observed_custom_module_class": {
#             "static": {
#                 FP32NonTraceable: ObservedNonTraceable,
#             }
#         },
#     }
#
#     model_prepared = prepare_fx(
#         model_fp32,
#         qconfig_dict,
#         prepare_custom_config_dict=prepare_custom_config_dict)
#
# calibrate / train (not shown)
#
# .. code:: python
#
#     convert_custom_config_dict = {
#         "observed_to_quantized_custom_module_class": {
#             "static": {
#                 ObservedNonTraceable: StaticQuantNonTraceable,
#             }
#         },
#     }
#     model_quantized = convert_fx(
#         model_prepared,
#         convert_custom_config_dict)
#
# post training dynamic/weight only quantization:
# in these two modes we don't need to observe the original model, so we
# only need to define the quantized model
#
# .. code:: python
#
#     class DynamicQuantNonTraceable: # or WeightOnlyQuantNonTraceable
#         ...
#         @classmethod
#         def from_observed(cls, ...):
#             ...
#
#     prepare_custom_config_dict = {
#         "non_traceable_module_class": [
#             FP32NonTraceable
#         ]
#     }
#
#
# .. code:: python
#
#     # The example is for post training quantization
#     model_fp32.eval()
#     model_prepared = prepare_fx(
#         model_fp32,
#         qconfig_dict,
#         prepare_custom_config_dict=prepare_custom_config_dict)
#
#     convert_custom_config_dict = {
#         "observed_to_quantized_custom_module_class": {
#             "dynamic": {
#                 FP32NonTraceable: DynamicQuantNonTraceable,
#             }
#         },
#     }
#     model_quantized = convert_fx(
#         model_prepared,
#         convert_custom_config_dict)
#
# You can also find examples for custom modules in test ``test_custom_module_class`` in ``torch/test/quantization/test_quantize_fx.py``.
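The ``from_float`` / ``from_observed`` protocol can be exercised without the quantization APIs. Below is a toy, self-contained sketch: the class names follow the tutorial's placeholders, but the observer and the int8 fake-quantize scheme are illustrative stand-ins, not PyTorch's observers:

```python
import torch
import torch.nn as nn

class FP32NonTraceable(nn.Module):
    """Float module with data dependent control flow."""
    def forward(self, x):
        return x if x.sum() > 0 else -x

class ObservedNonTraceable(nn.Module):
    """Wraps the float module and records the range convert will need."""
    def __init__(self, float_mod):
        super().__init__()
        self.float_mod = float_mod
        self.max_abs = 0.0

    def forward(self, x):
        out = self.float_mod(x)
        self.max_abs = max(self.max_abs, out.abs().max().item())  # "observe"
        return out

    @classmethod
    def from_float(cls, float_mod):
        return cls(float_mod)

class StaticQuantNonTraceable(nn.Module):
    """Freezes the observed range into a toy int8 fake-quantize scheme."""
    def __init__(self, float_mod, scale):
        super().__init__()
        self.float_mod = float_mod
        self.scale = scale

    def forward(self, x):
        out = self.float_mod(x)
        q = torch.clamp((out / self.scale).round(), -128, 127)
        return q * self.scale

    @classmethod
    def from_observed(cls, observed):
        return cls(observed.float_mod, max(observed.max_abs, 1e-8) / 127)

# the three-step flow prepare_fx / calibration / convert_fx would drive:
observed = ObservedNonTraceable.from_float(FP32NonTraceable())
observed(torch.tensor([1.0, -0.5]))                    # calibration pass
quantized = StaticQuantNonTraceable.from_observed(observed)
out = quantized(torch.tensor([1.0, -0.5]))
```

For the dynamic/weight-only path the observer step would be dropped and only the quantized class defined, as in the config above.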
