# -*- coding: utf-8 -*-
"""
(prototype) FX Graph Mode Quantization User Guide
===========================================================

**Author**: `Jerry Zhang <https://github.com/jerryzh168>`_

FX Graph Mode Quantization requires a symbolically traceable model.
We use the `FX framework <https://pytorch.org/docs/stable/fx.html>`_ to convert a symbolically traceable nn.Module instance to IR,
and we operate on the IR to execute the quantization passes.
Please post your questions about symbolically tracing your model in the `PyTorch Discussion Forum <https://discuss.pytorch.org/c/quantization/17>`_.

Quantization will only work on the symbolically traceable parts of your model.
Data dependent control flow (if statements / for loops, etc. using symbolically traced values) is one common pattern that is not supported.
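
For example, the following module (a hypothetical sketch, not part of this
tutorial's model) is not symbolically traceable, because the branch taken
depends on the runtime value of ``x``:

.. code:: python

    class NotTraceable(nn.Module):

        def forward(self, x):
            # data dependent control flow: the condition depends on
            # a value that is a Proxy during tracing
            if x.sum() > 0:
                return x * 2
            return x
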
If your model is not symbolically traceable end to end, you have a couple of options to enable FX Graph Mode Quantization only on a part of the model.
You can use any combination of these options:

1. Non-traceable code doesn’t need to be quantized
    a. Symbolically trace only the code that needs to be quantized
    b. Skip symbolically tracing the non-traceable code

2. Non-traceable code needs to be quantized
    a. Refactor your code to make it symbolically traceable
    b. Write your own observed and quantized submodule

"""

####################################################################
# If the code that is not symbolically traceable does not need to be quantized, we have the following two options
# to run FX Graph Mode Quantization:
#
# 1.a. Symbolically trace only the code that needs to be quantized
# -----------------------------------------------------------------
#
# When the whole model is not symbolically traceable but the submodule we want to quantize is
# symbolically traceable, we can run quantization only on that submodule.
#
#
# before:
#
# .. code:: python
#
#     class M(nn.Module):
#
#         def forward(self, x):
#             x = non_traceable_code_1(x)
#             x = traceable_code(x)
#             x = non_traceable_code_2(x)
#             return x
#
#
# after:
#
# .. code:: python
#
#     class FP32Traceable(nn.Module):
#
#         def forward(self, x):
#             x = traceable_code(x)
#             return x
#
#     class M(nn.Module):
#
#         def __init__(self):
#             super().__init__()
#             self.traceable_submodule = FP32Traceable(...)
#
#         def forward(self, x):
#             x = non_traceable_code_1(x)
#             # We'll only symbolic trace/quantize this submodule
#             x = self.traceable_submodule(x)
#             x = non_traceable_code_2(x)
#             return x
#
#
# quantization code:
#
# .. code:: python
#
#     qconfig_dict = {"": qconfig}
#     model_fp32.traceable_submodule = \
#         prepare_fx(model_fp32.traceable_submodule, qconfig_dict)
#
# Note that if the original model needs to be preserved, you will have to
# copy it yourself before calling the quantization APIs.
#
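# To finish the flow, a minimal sketch of the remaining steps for post
# training static quantization (``calibration_data`` is a hypothetical
# iterable of representative inputs):
#
# .. code:: python
#
#     # run data through the whole model so the observers inside the
#     # prepared submodule record activation statistics
#     for data in calibration_data:
#         model_fp32(data)
#     # swap the observed submodule for a quantized one
#     model_fp32.traceable_submodule = \
#         convert_fx(model_fp32.traceable_submodule)
#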

#####################################################
# 1.b. Skip symbolically tracing the non-traceable code
# ------------------------------------------------------
# When we have some non-traceable code in the module, and this part of the code doesn’t need to be quantized,
# we can factor out this part of the code into a submodule and skip symbolically tracing that submodule.
#
#
# before:
#
# .. code:: python
#
#     class M(nn.Module):
#
#         def forward(self, x):
#             x = self.traceable_code_1(x)
#             x = non_traceable_code(x)
#             x = self.traceable_code_2(x)
#             return x
#
#
# after, with the non-traceable parts moved to a submodule and marked as a leaf:
#
# .. code:: python
#
#     class FP32NonTraceable(nn.Module):
#
#         def forward(self, x):
#             x = non_traceable_code(x)
#             return x
#
#     class M(nn.Module):
#
#         def __init__(self):
#             ...
#             self.non_traceable_submodule = FP32NonTraceable(...)
#
#         def forward(self, x):
#             x = self.traceable_code_1(x)
#             # we will configure the quantization call to not trace through
#             # this submodule
#             x = self.non_traceable_submodule(x)
#             x = self.traceable_code_2(x)
#             return x
#
# quantization code:
#
# .. code:: python
#
#     qconfig_dict = {"": qconfig}
#
#     prepare_custom_config_dict = {
#         # option 1: reference the submodule by name
#         "non_traceable_module_name": "non_traceable_submodule",
#         # option 2: reference the submodule by class
#         "non_traceable_module_class": [FP32NonTraceable],
#     }
#     model_prepared = prepare_fx(
#         model_fp32,
#         qconfig_dict,
#         prepare_custom_config_dict=prepare_custom_config_dict,
#     )
#
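# Since the non-traceable submodule is skipped during tracing, it stays in
# fp32 while the rest of the model is quantized as usual. A minimal sketch of
# the remaining steps for post training static quantization
# (``calibration_data`` is a hypothetical iterable):
#
# .. code:: python
#
#     for data in calibration_data:
#         model_prepared(data)
#     model_quantized = convert_fx(model_prepared)
#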
# If the code that is not symbolically traceable needs to be quantized, we have the following two options:

##########################################################
# 2.a. Refactor your code to make it symbolically traceable
# ----------------------------------------------------------
# If the code is easy to refactor into something symbolically traceable,
# we can refactor it to remove the use of non-traceable constructs in Python.
#
# More information about symbolic tracing support can be found in the
# `FX documentation <https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing>`_.
#
# before:
#
# .. code:: python
#
#     def transpose_for_scores(self, x):
#         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#         x = x.view(*new_x_shape)
#         return x.permute(0, 2, 1, 3)
#
#
# This is not symbolically traceable because the unpacking in ``x.view(*new_x_shape)``
# is not supported. However, it is easy to remove the unpacking,
# since ``x.view`` also supports list input.
#
#
# after:
#
# .. code:: python
#
#     def transpose_for_scores(self, x):
#         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
#         x = x.view(new_x_shape)
#         return x.permute(0, 2, 1, 3)
#
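# Before running quantization you can verify that the refactor worked by
# tracing the module directly (a sketch; ``MyModel`` is a hypothetical module
# that uses the refactored method):
#
# .. code:: python
#
#     import torch.fx
#
#     # raises a TraceError for the ``x.view(*new_x_shape)`` version,
#     # succeeds for the refactored one
#     traced = torch.fx.symbolic_trace(MyModel())
#     print(traced.graph)
#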
#
# quantization code:
#
# This can be combined with other approaches, and the quantization code
# depends on the model.
#
#

#######################################################
# 2.b. Write your own observed and quantized submodule
# -----------------------------------------------------
#
# If the non-traceable code can’t be refactored to be symbolically traceable,
# for example because it contains loops that can’t be eliminated, like nn.LSTM,
# we’ll need to factor out the non-traceable code into a submodule (we call it a CustomModule in FX Graph Mode Quantization) and then either
# define the observed and quantized versions of the submodule (for post training static quantization or quantization aware training for static quantization),
# or define just the quantized version (for post training dynamic and weight only quantization).
#
#
# before:
#
# .. code:: python
#
#     class M(nn.Module):
#
#         def forward(self, x):
#             x = traceable_code_1(x)
#             x = non_traceable_code(x)
#             x = traceable_code_2(x)
#             return x
#
# after:
#
# 1. Factor out non_traceable_code into FP32NonTraceable,
#    the non-traceable logic wrapped in a module
#
# .. code:: python
#
#     class FP32NonTraceable(nn.Module):
#         ...
#
#
# 2. Define an observed version of FP32NonTraceable
#
# .. code:: python
#
#     class ObservedNonTraceable(nn.Module):
#
#         @classmethod
#         def from_float(cls, ...):
#             ...
#
# 3. Define a statically quantized version of FP32NonTraceable
#    and a class method "from_observed" to convert from ObservedNonTraceable
#    to StaticQuantNonTraceable
#
# .. code:: python
#
#     class StaticQuantNonTraceable(nn.Module):
#
#         @classmethod
#         def from_observed(cls, ...):
#             ...
#
#
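# For concreteness, here is a minimal sketch of what steps 2 and 3 might look
# like for a toy module. Everything inside these classes (the observer, the
# requantization logic) is an illustrative assumption; only the ``from_float``
# and ``from_observed`` class methods are required by the custom module API:
#
# .. code:: python
#
#     import torch
#     import torch.nn as nn
#     from torch.quantization import default_observer
#
#     class ObservedNonTraceable(nn.Module):
#         def __init__(self, float_mod):
#             super().__init__()
#             self.inner = float_mod
#             # records activation statistics during calibration
#             self.activation_post_process = default_observer()
#
#         def forward(self, x):
#             out = self.inner(x)
#             self.activation_post_process(out)
#             return out
#
#         @classmethod
#         def from_float(cls, float_mod):
#             # called by prepare_fx to swap in the observed version
#             return cls(float_mod)
#
#     class StaticQuantNonTraceable(nn.Module):
#         def __init__(self, inner, scale, zero_point):
#             super().__init__()
#             self.inner = inner
#             self.scale = scale
#             self.zero_point = zero_point
#
#         def forward(self, x):
#             # run the non-traceable logic in fp32, requantize the output
#             out = self.inner(x.dequantize())
#             return torch.quantize_per_tensor(
#                 out, self.scale, self.zero_point, torch.quint8)
#
#         @classmethod
#         def from_observed(cls, observed_mod):
#             # called by convert_fx; turn observed statistics into qparams
#             scale, zero_point = \
#                 observed_mod.activation_post_process.calculate_qparams()
#             return cls(observed_mod.inner, float(scale), int(zero_point))
#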
# .. code:: python
#
#     # refactor the parent class to call FP32NonTraceable
#     class M(nn.Module):
#
#         def __init__(self):
#             ...
#             self.non_traceable_submodule = FP32NonTraceable(...)
#
#         def forward(self, x):
#             x = self.traceable_code_1(x)
#             # this part will be quantized manually
#             x = self.non_traceable_submodule(x)
#             x = self.traceable_code_2(x)
#             return x
#
#
# quantization code:
#
#
# .. code:: python
#
#     # post training static quantization or
#     # quantization aware training (that produces a statically quantized module)
#     prepare_custom_config_dict = {
#         "float_to_observed_custom_module_class": {
#             "static": {
#                 FP32NonTraceable: ObservedNonTraceable,
#             }
#         },
#     }
#
#     model_prepared = prepare_fx(
#         model_fp32,
#         qconfig_dict,
#         prepare_custom_config_dict=prepare_custom_config_dict)
#
# calibrate / train (not shown)
#
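# A minimal calibration sketch (assuming post training static quantization
# and a hypothetical ``calibration_data`` iterable):
#
# .. code:: python
#
#     model_prepared.eval()
#     with torch.no_grad():
#         for data in calibration_data:
#             model_prepared(data)
#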
# .. code:: python
#
#     convert_custom_config_dict = {
#         "observed_to_quantized_custom_module_class": {
#             "static": {
#                 ObservedNonTraceable: StaticQuantNonTraceable,
#             }
#         },
#     }
#     model_quantized = convert_fx(
#         model_prepared,
#         convert_custom_config_dict=convert_custom_config_dict)
#
# For post training dynamic and weight only quantization, we don't need to
# observe the original model, so we only need to define the quantized version
# of the module:
#
# .. code:: python
#
#     class DynamicQuantNonTraceable(nn.Module): # or WeightOnlyQuantNonTraceable
#         ...
#         @classmethod
#         def from_observed(cls, ...):
#             ...
#
#     prepare_custom_config_dict = {
#         "non_traceable_module_class": [
#             FP32NonTraceable
#         ]
#     }
#
#
# .. code:: python
#
#     # The example is for post training quantization
#     model_fp32.eval()
#     model_prepared = prepare_fx(
#         model_fp32,
#         qconfig_dict,
#         prepare_custom_config_dict=prepare_custom_config_dict)
#
#     convert_custom_config_dict = {
#         "observed_to_quantized_custom_module_class": {
#             "dynamic": {
#                 FP32NonTraceable: DynamicQuantNonTraceable,
#             }
#         },
#     }
#     model_quantized = convert_fx(
#         model_prepared,
#         convert_custom_config_dict=convert_custom_config_dict)
#
# You can also find examples for custom modules in the test ``test_custom_module_class`` in ``test/quantization/test_quantize_fx.py``.