quantize.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. import copy
  2. import itertools
  3. import warnings
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.quantized as nnq
  7. from torch.nn.intrinsic import _FusedModule
  8. from torch.ao.quantization.quantization_mappings import (
  9. get_default_dynamic_quant_module_mappings,
  10. get_default_static_quant_module_mappings,
  11. get_default_static_quant_reference_module_mappings,
  12. get_default_qat_module_mappings,
  13. get_default_qconfig_propagation_list,
  14. no_observer_set,
  15. _has_special_act_post_process,
  16. _get_special_act_post_process,
  17. )
  18. from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
  19. from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
  20. from torch.ao.quantization.qconfig import (
  21. add_module_to_qconfig_obs_ctr,
  22. default_dynamic_qconfig,
  23. float16_dynamic_qconfig,
  24. float_qparams_weight_only_qconfig,
  25. float_qparams_weight_only_qconfig_4bit,
  26. activation_is_memoryless)
  27. from torch.nn.utils.parametrize import type_before_parametrizations
  28. def is_activation_post_process(module):
  29. return (isinstance(module, torch.ao.quantization.ObserverBase) or
  30. isinstance(module, torch.ao.quantization.FakeQuantizeBase))
  31. def _propagate_qconfig_helper(module, qconfig_dict,
  32. qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
  33. r"""This is a helper function for `propagate_qconfig_`
  34. Args:
  35. module: input module
  36. qconfig_dict: dictionary that maps from name of submodule to quantization
  37. configuration
  38. qconfig_parent: quantization config of parent module, we will fallback to
  39. this config when there is no specified config for current
  40. module
  41. prefix: corresponding prefix of the current module, used as key in
  42. qconfig_dict
  43. prepare_custom_config_dict: dictionary for custom handling of modules
  44. see docs for :func:`~torch.ao.quantization.prepare_fx`
  45. Return:
  46. None, module is modified inplace with qconfig attached
  47. """
  48. module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
  49. module_qconfig = qconfig_dict.get(prefix, module_qconfig)
  50. module_qconfig = getattr(module, 'qconfig', module_qconfig)
  51. torch.ao.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)
  52. qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
  53. module.qconfig = qconfig_with_device_check
  54. for name, child in module.named_children():
  55. module_prefix = prefix + '.' + name if prefix else name
  56. # do no not propagate qconfig to child if child is non traceable
  57. if prepare_custom_config_dict is None or not (
  58. name in prepare_custom_config_dict.get("non_traceable_module_name", [])
  59. or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
  60. ):
  61. _propagate_qconfig_helper(
  62. child, qconfig_dict, qconfig_with_device_check, module_prefix
  63. )
  64. def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
  65. r"""Propagate qconfig through the module hierarchy and assign `qconfig`
  66. attribute on each leaf module
  67. Args:
  68. module: input module
  69. qconfig_dict: dictionary that maps from name or type of submodule to
  70. quantization configuration, qconfig applies to all submodules of a
  71. given module unless qconfig for the submodules are specified (when
  72. the submodule already has qconfig attribute)
  73. prepare_custom_config_dict: dictionary for custom handling of modules
  74. see docs for :func:`~torch.ao.quantization.prepare_fx`
  75. Return:
  76. None, module is modified inplace with qconfig attached
  77. """
  78. if qconfig_dict is None:
  79. qconfig_dict = {}
  80. if prepare_custom_config_dict is None:
  81. prepare_custom_config_dict = {}
  82. _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)
  83. def _observer_forward_hook(self, input, output):
  84. r"""Forward hook that calls observer on the output
  85. """
  86. return self.activation_post_process(output)
  87. def _observer_forward_pre_hook(self, input):
  88. r"""Forward pre hook that calls observer on the output
  89. """
  90. return self.activation_post_process(input[0])
  91. def register_activation_post_process_hook(module, pre_hook=False):
  92. assert hasattr(module, 'activation_post_process'), \
  93. 'Expect activation_post_process attribute already attached to the module'
  94. if pre_hook:
  95. handle = module.register_forward_pre_hook(_observer_forward_pre_hook)
  96. module._forward_pre_hooks.move_to_end(handle.id, last=False)
  97. else:
  98. handle = module.register_forward_hook(_observer_forward_hook)
  99. module._forward_hooks.move_to_end(handle.id, last=False)
  100. def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
  101. r"""Add observer for the leaf child of the module.
  102. This function insert observer module to all leaf child module that
  103. has a valid qconfig attribute.
  104. Args:
  105. module: input module with qconfig attributes for all the leaf modules that we want to quantize
  106. qconfig_propagation_list: a list of quantizable modules that will have observers added to them
  107. if they are leaf nodes
  108. device: parent device, if any
  109. non_leaf_module_list: list of non-leaf modules we want to add observer
  110. Return:
  111. None, module is modified inplace with added observer modules and forward_hooks
  112. """
  113. if qconfig_propagation_list is None:
  114. qconfig_propagation_list = get_default_qconfig_propagation_list()
  115. if custom_module_class_mapping is None:
  116. custom_module_class_mapping = {}
  117. # respect device affinity when adding observers
  118. if device is None:
  119. devices = get_unique_devices_(module)
  120. assert len(devices) <= 1, (
  121. "add_observer_ only works with cpu or single-device CUDA modules, "
  122. "but got devices {}".format(devices)
  123. )
  124. device = next(iter(devices)) if len(devices) > 0 else None
  125. def get_activation_post_process(qconfig, device, special_act_post_process=None):
  126. activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
  127. if device is not None:
  128. activation.to(device)
  129. return activation
  130. def needs_observation(m):
  131. return hasattr(m, 'qconfig') and m.qconfig is not None
  132. def insert_activation_post_process(m, special_act_post_process=None):
  133. """ Adds an activation post process module and register
  134. a pre or post hook that calls the module
  135. """
  136. # We don't insert observer/fake_quantize for DeQuantStub
  137. if needs_observation(m) and not isinstance(m, DeQuantStub):
  138. # observer and hook will be gone after we swap the module
  139. m.add_module('activation_post_process', get_activation_post_process(
  140. m.qconfig, device, special_act_post_process))
  141. # Register observer as the first entry in the hook list
  142. # All post forward hooks are preserved and will be executed after the observer before convert
  143. register_activation_post_process_hook(m, pre_hook=activation_is_memoryless(m.qconfig))
  144. for name, child in module.named_children():
  145. # TODO remove Dropout special after codebase stable
  146. if type_before_parametrizations(child) in [nn.Dropout]:
  147. continue
  148. elif type_before_parametrizations(child) in [nnq.FloatFunctional, nnq.QFunctional]:
  149. if needs_observation(child):
  150. child.activation_post_process = get_activation_post_process(child.qconfig, device)
  151. elif isinstance(child, _FusedModule):
  152. # activation_post_process are now added directly to nn.Sequentail/_FusedModule
  153. if needs_observation(child):
  154. insert_activation_post_process(child)
  155. elif _has_special_act_post_process(child):
  156. special_act_post_process = _get_special_act_post_process(child)
  157. insert_activation_post_process(child, special_act_post_process)
  158. elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
  159. if needs_observation(child):
  160. insert_activation_post_process(child)
  161. elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
  162. observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
  163. setattr(module, name, observed_child)
  164. # TODO: These are the modules that cannot be observed
  165. # Once there are more, we should move them to a separate list
  166. if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
  167. insert_activation_post_process(observed_child)
  168. else:
  169. add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
  170. # Insert observers only for leaf nodes, note that this observer is for
  171. # the output of the module, for input QuantStub will observe them
  172. if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
  173. and type_before_parametrizations(module) in qconfig_propagation_list:
  174. insert_activation_post_process(module)
  175. def get_unique_devices_(module):
  176. return {p.device for p in module.parameters()} | \
  177. {p.device for p in module.buffers()}
  178. def add_quant_dequant(module):
  179. r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
  180. Note that this function will modify the children of module inplace and it
  181. can return a new module which wraps the input module as well.
  182. Args:
  183. module: input module with qconfig attributes for all the leaf modules
  184. that we want to quantize
  185. Return:
  186. Either the inplace modified module with submodules wrapped in
  187. `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
  188. wraps the input module, the latter case only happens when the input
  189. module is a leaf module and we want to quantize it.
  190. """
  191. if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
  192. return QuantWrapper(module)
  193. for name, child in module.named_children():
  194. module._modules[name] = add_quant_dequant(child)
  195. return module
  196. def prepare(model, inplace=False, allow_list=None,
  197. observer_non_leaf_module_list=None,
  198. prepare_custom_config_dict=None):
  199. r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
  200. Quantization configuration should be assigned preemptively
  201. to individual submodules in `.qconfig` attribute.
  202. The model will be attached with observer or fake quant modules, and qconfig
  203. will be propagated.
  204. Args:
  205. `model`: input model to be modified in-place
  206. `inplace`: carry out model transformations in-place, the original module is mutated
  207. `allow_list`: list of quantizable modules
  208. `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
  209. `prepare_custom_config_dict`: customization configuration dictionary for prepare function
  210. .. code-block:: python
  211. # Example of prepare_custom_config_dict:
  212. prepare_custom_config_dict = {
  213. # user will manually define the corresponding observed
  214. # module class which has a from_float class method that converts
  215. # float custom module to observed custom module
  216. "float_to_observed_custom_module_class": {
  217. CustomModule: ObservedCustomModule
  218. }
  219. }
  220. """
  221. torch._C._log_api_usage_once("quantization_api.quantize.prepare")
  222. if prepare_custom_config_dict is None:
  223. prepare_custom_config_dict = {}
  224. custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
  225. if not inplace:
  226. model = copy.deepcopy(model)
  227. # TODO: remove allow_list
  228. qconfig_propagation_list = allow_list
  229. if allow_list is None:
  230. qconfig_propagation_list = get_default_qconfig_propagation_list()
  231. propagate_qconfig_(model, qconfig_dict=None)
  232. # sanity check common API misusage
  233. if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
  234. warnings.warn("None of the submodule got qconfig applied. Make sure you "
  235. "passed correct configuration through `qconfig_dict` or "
  236. "by assigning the `.qconfig` attribute directly on submodules")
  237. add_observer_(
  238. model, qconfig_propagation_list, observer_non_leaf_module_list,
  239. custom_module_class_mapping=custom_module_class_mapping)
  240. return model
  241. def _remove_activation_post_process(module):
  242. # TODO: maybe we should change activation_post_process to _activation_post_process
  243. # to prevent it from being used by user
  244. if hasattr(module, 'activation_post_process') and \
  245. is_activation_post_process(module.activation_post_process):
  246. delattr(module, 'activation_post_process')
  247. # remove activation_post_proceess pre and post hooks
  248. def remove_hooks(pre_hook=False):
  249. hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
  250. observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
  251. handle_ids_to_remove = set()
  252. for handle_id, hook_fn in hook_map.items():
  253. if hook_fn is observer_hook:
  254. handle_ids_to_remove.add(handle_id)
  255. for handle_id in handle_ids_to_remove:
  256. hook_map.pop(handle_id)
  257. remove_hooks(pre_hook=True)
  258. remove_hooks(pre_hook=False)
  259. # TODO: rename to something more general
  260. def _remove_qconfig(module):
  261. r"""Clean up the qconfig left in the module so that new qconfig can be
  262. propagated.
  263. Args:
  264. module: module to be cleaned up
  265. """
  266. for child in module.children():
  267. _remove_qconfig(child)
  268. if hasattr(module, "qconfig"):
  269. del module.qconfig
  270. _remove_activation_post_process(module)
  271. def quantize(model, run_fn, run_args, mapping=None, inplace=False):
  272. r"""Quantize the input float model with post training static quantization.
  273. First it will prepare the model for calibration, then it calls
  274. `run_fn` which will run the calibration step, after that we will
  275. convert the model to a quantized model.
  276. Args:
  277. model: input float model
  278. run_fn: a calibration function for calibrating the prepared model
  279. run_args: positional arguments for `run_fn`
  280. inplace: carry out model transformations in-place, the original module is mutated
  281. mapping: correspondence between original module types and quantized counterparts
  282. Return:
  283. Quantized model.
  284. """
  285. torch._C._log_api_usage_once("quantization_api.quantize.quantize")
  286. if mapping is None:
  287. mapping = get_default_static_quant_module_mappings()
  288. if not inplace:
  289. model = copy.deepcopy(model)
  290. model.eval()
  291. prepare(model, inplace=True)
  292. run_fn(model, *run_args)
  293. convert(model, mapping, inplace=True)
  294. return model
  295. def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
  296. mapping=None, inplace=False):
  297. r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
  298. Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
  299. For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
  300. by default is performed for layers with large weights size - i.e. Linear and RNN variants.
  301. Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
  302. If `qconfig` is provided, the `dtype` argument is ignored.
  303. Args:
  304. model: input model
  305. qconfig_spec: Either:
  306. - A dictionary that maps from name or type of submodule to quantization
  307. configuration, qconfig applies to all submodules of a given
  308. module unless qconfig for the submodules are specified (when the
  309. submodule already has qconfig attribute). Entries in the dictionary
  310. need to be QConfig instances.
  311. - A set of types and/or submodule names to apply dynamic quantization to,
  312. in which case the `dtype` argument is used to specify the bit-width
  313. inplace: carry out model transformations in-place, the original module is mutated
  314. mapping: maps type of a submodule to a type of corresponding dynamically quantized version
  315. with which the submodule needs to be replaced
  316. """
  317. torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
  318. if qconfig_spec is None:
  319. if dtype == torch.qint8:
  320. qconfig_spec = {
  321. nn.Linear : default_dynamic_qconfig,
  322. nn.LSTM : default_dynamic_qconfig,
  323. nn.GRU : default_dynamic_qconfig,
  324. nn.LSTMCell : default_dynamic_qconfig,
  325. nn.RNNCell : default_dynamic_qconfig,
  326. nn.GRUCell : default_dynamic_qconfig,
  327. }
  328. elif dtype == torch.float16:
  329. qconfig_spec = {
  330. nn.Linear : float16_dynamic_qconfig,
  331. nn.LSTM : float16_dynamic_qconfig,
  332. nn.GRU : float16_dynamic_qconfig,
  333. nn.LSTMCell : float16_dynamic_qconfig,
  334. nn.RNNCell : float16_dynamic_qconfig,
  335. nn.GRUCell : float16_dynamic_qconfig,
  336. }
  337. elif dtype == torch.quint8:
  338. qconfig_spec = {
  339. nn.EmbeddingBag : float_qparams_weight_only_qconfig,
  340. nn.Embedding : float_qparams_weight_only_qconfig,
  341. }
  342. elif dtype == torch.quint4x2:
  343. qconfig_spec = {
  344. nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit,
  345. }
  346. else:
  347. raise ValueError(
  348. "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
  349. elif isinstance(qconfig_spec, set):
  350. if dtype is torch.qint8:
  351. default_qconfig = default_dynamic_qconfig
  352. elif dtype is torch.float16:
  353. default_qconfig = float16_dynamic_qconfig
  354. elif dtype is torch.quint8:
  355. default_qconfig = float_qparams_weight_only_qconfig
  356. elif dtype is torch.quint4x2:
  357. default_qconfig = float_qparams_weight_only_qconfig_4bit
  358. else:
  359. raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
  360. qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
  361. if mapping is None:
  362. mapping = get_default_dynamic_quant_module_mappings()
  363. if not inplace:
  364. model = copy.deepcopy(model)
  365. model.eval()
  366. propagate_qconfig_(model, qconfig_spec)
  367. convert(model, mapping, inplace=True)
  368. return model
  369. def prepare_qat(model, mapping=None, inplace=False):
  370. r"""
  371. Prepares a copy of the model for quantization calibration or
  372. quantization-aware training and converts it to quantized version.
  373. Quantization configuration should be assigned preemptively
  374. to individual submodules in `.qconfig` attribute.
  375. Args:
  376. model: input model to be modified in-place
  377. mapping: dictionary that maps float modules to quantized modules to be
  378. replaced.
  379. inplace: carry out model transformations in-place, the original module
  380. is mutated
  381. """
  382. torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
  383. assert model.training, "prepare_qat only works on models in training mode"
  384. if mapping is None:
  385. mapping = get_default_qat_module_mappings()
  386. if not inplace:
  387. model = copy.deepcopy(model)
  388. propagate_qconfig_(model, qconfig_dict=None)
  389. convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
  390. prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
  391. return model
  392. def quantize_qat(model, run_fn, run_args, inplace=False):
  393. r"""Do quantization aware training and output a quantized model
  394. Args:
  395. model: input model
  396. run_fn: a function for evaluating the prepared model, can be a
  397. function that simply runs the prepared model or a training
  398. loop
  399. run_args: positional arguments for `run_fn`
  400. Return:
  401. Quantized model.
  402. """
  403. torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
  404. if not inplace:
  405. model = copy.deepcopy(model)
  406. model.train()
  407. prepare_qat(model, inplace=True)
  408. run_fn(model, *run_args)
  409. convert(model, inplace=True)
  410. return model
  411. def convert(
  412. module, mapping=None, inplace=False, remove_qconfig=True,
  413. is_reference=False, convert_custom_config_dict=None):
  414. r"""Converts submodules in input module to a different module according to `mapping`
  415. by calling `from_float` method on the target module class. And remove qconfig at the
  416. end if remove_qconfig is set to True.
  417. Args:
  418. `module`: prepared and calibrated module
  419. `mapping`: a dictionary that maps from source module type to target
  420. module type, can be overwritten to allow swapping user defined
  421. Modules
  422. `inplace`: carry out model transformations in-place, the original module
  423. is mutated
  424. `convert_custom_config_dict`: custom configuration dictionary for convert function
  425. .. code-block:: python
  426. # Example of convert_custom_config_dict:
  427. convert_custom_config_dict = {
  428. # user will manually define the corresponding quantized
  429. # module class which has a from_observed class method that converts
  430. # observed custom module to quantized custom module
  431. "observed_to_quantized_custom_module_class": {
  432. ObservedCustomModule: QuantizedCustomModule
  433. }
  434. }
  435. """
  436. torch._C._log_api_usage_once("quantization_api.quantize.convert")
  437. if not inplace:
  438. module = copy.deepcopy(module)
  439. _convert(
  440. module, mapping, inplace=True, is_reference=is_reference,
  441. convert_custom_config_dict=convert_custom_config_dict)
  442. if remove_qconfig:
  443. _remove_qconfig(module)
  444. return module
  445. def _convert(
  446. module, mapping=None, inplace=False,
  447. is_reference=False, convert_custom_config_dict=None):
  448. r"""Converts submodules in input module to a different module according to `mapping`
  449. by calling `from_float` method on the target module class
  450. Args:
  451. module: input module
  452. mapping: a dictionary that maps from source module type to target
  453. module type, can be overwritten to allow swapping user defined
  454. Modules
  455. inplace: carry out model transformations in-place, the original module
  456. is mutated
  457. is_reference: a flag to enable quantized reference module
  458. """
  459. if mapping is None:
  460. mapping = get_default_static_quant_reference_module_mappings() if is_reference \
  461. else get_default_static_quant_module_mappings()
  462. if convert_custom_config_dict is None:
  463. convert_custom_config_dict = {}
  464. custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
  465. if not inplace:
  466. module = copy.deepcopy(module)
  467. reassign = {}
  468. for name, mod in module.named_children():
  469. # both fused modules and observed custom modules are
  470. # swapped as one unit
  471. if not isinstance(mod, _FusedModule) and \
  472. type_before_parametrizations(mod) not in custom_module_class_mapping:
  473. _convert(mod, mapping, True, # inplace
  474. is_reference, convert_custom_config_dict)
  475. reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
  476. for key, value in reassign.items():
  477. module._modules[key] = value
  478. return module
  479. def swap_module(mod, mapping, custom_module_class_mapping):
  480. r"""Swaps the module if it has a quantized counterpart and it has an
  481. `observer` attached.
  482. Args:
  483. mod: input module
  484. mapping: a dictionary that maps from nn module to nnq module
  485. Return:
  486. The corresponding quantized module of `mod`
  487. """
  488. new_mod = mod
  489. if hasattr(mod, 'qconfig') and mod.qconfig is not None:
  490. swapped = False
  491. if type_before_parametrizations(mod) in custom_module_class_mapping:
  492. new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod)
  493. swapped = True
  494. elif type_before_parametrizations(mod) in mapping:
  495. qmod = mapping[type_before_parametrizations(mod)]
  496. if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
  497. assert mod.qconfig is not None
  498. weight_post_process = mod.qconfig.weight()
  499. weight_post_process(mod.weight)
  500. weight_qparams = get_qparam_dict(weight_post_process)
  501. new_mod = qmod.from_float(mod, weight_qparams)
  502. else:
  503. new_mod = qmod.from_float(mod)
  504. swapped = True
  505. if swapped:
  506. # Preserve module's pre forward hooks. They'll be called on quantized input
  507. for pre_hook_fn in mod._forward_pre_hooks.values():
  508. new_mod.register_forward_pre_hook(pre_hook_fn)
  509. # Preserve module's post forward hooks except _observer_forward_hook
  510. # After convert they'll work with quantized output
  511. for hook_fn in mod._forward_hooks.values():
  512. if hook_fn is not _observer_forward_hook:
  513. new_mod.register_forward_hook(hook_fn)
  514. # respect device affinity when swapping modules
  515. devices = get_unique_devices_(mod)
  516. assert len(devices) <= 1, (
  517. "swap_module only works with cpu or single-device CUDA modules, "
  518. "but got devices {}".format(devices)
  519. )
  520. device = next(iter(devices)) if len(devices) > 0 else None
  521. if device:
  522. new_mod.to(device)
  523. return new_mod
  524. def get_observer_dict(mod, target_dict, prefix=""):
  525. r"""Traverse the modules and save all observers into dict.
  526. This is mainly used for quantization accuracy debug
  527. Args:
  528. mod: the top module we want to save all observers
  529. prefix: the prefix for the current module
  530. target_dict: the dictionary used to save all the observers
  531. """
  532. def get_prefix(prefix):
  533. return prefix if prefix == "" else prefix + '.'
  534. if hasattr(mod, 'activation_post_process'):
  535. target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
  536. for name, child in mod.named_children():
  537. module_prefix = get_prefix(prefix) + name if prefix else name
  538. get_observer_dict(child, target_dict, module_prefix)