qconfig.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. from collections import namedtuple
  2. from typing import Optional, Any
  3. import torch
  4. import torch.nn as nn
  5. from torch.ao.quantization.fake_quantize import (
  6. FakeQuantize,
  7. FakeQuantizeBase,
  8. default_fake_quant,
  9. default_dynamic_fake_quant,
  10. default_per_channel_weight_fake_quant,
  11. default_weight_fake_quant,
  12. default_fused_act_fake_quant,
  13. default_fused_wt_fake_quant,
  14. FusedMovingAvgObsFakeQuantize,
  15. default_fused_per_channel_wt_fake_quant,
  16. default_embedding_fake_quant,
  17. default_embedding_fake_quant_4bit,
  18. fused_wt_fake_quant_range_neg_127_to_127,
  19. fused_per_channel_wt_fake_quant_range_neg_127_to_127,
  20. )
  21. from .observer import (
  22. HistogramObserver,
  23. MovingAverageMinMaxObserver,
  24. NoopObserver,
  25. PlaceholderObserver,
  26. ReuseInputObserver,
  27. default_debug_observer,
  28. default_dynamic_quant_observer,
  29. default_float_qparams_observer,
  30. default_float_qparams_observer_4bit,
  31. default_observer,
  32. default_per_channel_weight_observer,
  33. default_placeholder_observer,
  34. default_weight_observer,
  35. weight_observer_range_neg_127_to_127,
  36. per_channel_weight_observer_range_neg_127_to_127,
  37. default_reuse_input_observer,
  38. )
  39. import warnings
  40. class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
  41. """
  42. Describes how to quantize a layer or a part of the network by providing
  43. settings (observer classes) for activations and weights respectively.
  44. Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
  45. instances on invocation, not the concrete observer instances themselves.
  46. Quantization preparation function will instantiate observers multiple times for each of the layers.
  47. Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args`
  48. method (that behaves like functools.partial)::
  49. my_qconfig = QConfig(
  50. activation=MinMaxObserver.with_args(dtype=torch.qint8),
  51. weight=default_observer.with_args(dtype=torch.qint8))
  52. """
  53. def __new__(cls, activation, weight):
  54. # catch common mistakes
  55. if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
  56. raise ValueError("QConfig received observer instance, please pass observer class instead. " +
  57. "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
  58. return super(QConfig, cls).__new__(cls, activation, weight)
  59. class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
  60. """
  61. Describes how to dynamically quantize a layer or a part of the network by providing
  62. settings (observer classes) for weights.
  63. It's like QConfig, but for dynamic quantization.
  64. Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
  65. instances on invocation, not the concrete observer instances themselves.
  66. Quantization function will instantiate observers multiple times for each of the layers.
  67. Observer classes have usually reasonable default arguments, but they can be overwritten with `with_args`
  68. method (that behaves like functools.partial)::
  69. my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
  70. """
  71. def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
  72. # catch common mistakes
  73. if isinstance(weight, nn.Module):
  74. raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
  75. "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
  76. warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead")
  77. return super(QConfigDynamic, cls).__new__(cls, activation, weight)
  78. default_qconfig = QConfig(activation=default_observer,
  79. weight=default_weight_observer)
  80. """
  81. Default qconfig configuration.
  82. """
  83. default_debug_qconfig = QConfig(weight=default_weight_observer,
  84. activation=default_debug_observer)
  85. """
  86. Default qconfig configuration for debugging.
  87. """
  88. default_per_channel_qconfig = QConfig(activation=default_observer,
  89. weight=default_per_channel_weight_observer)
  90. """
  91. Default qconfig configuration for per channel weight quantization.
  92. """
  93. default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
  94. weight=default_weight_observer)
  95. """
  96. Default dynamic qconfig.
  97. """
  98. float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float32, compute_dtype=torch.float16),
  99. weight=PlaceholderObserver.with_args(dtype=torch.float16))
  100. """
  101. Dynamic qconfig with weights quantized to `torch.float16`.
  102. """
  103. float16_static_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float16),
  104. weight=PlaceholderObserver.with_args(dtype=torch.float16))
  105. """
  106. Dynamic qconfig with both activations and weights quantized to `torch.float16`.
  107. """
  108. per_channel_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
  109. weight=default_per_channel_weight_observer)
  110. """
  111. Dynamic qconfig with weights quantized per channel.
  112. """
  113. float_qparams_weight_only_qconfig = QConfig(
  114. activation=default_placeholder_observer,
  115. weight=default_float_qparams_observer)
  116. """
  117. Dynamic qconfig with weights quantized with a floating point zero_point.
  118. """
  119. float_qparams_weight_only_qconfig_4bit = QConfig(
  120. activation=default_placeholder_observer,
  121. weight=default_float_qparams_observer_4bit)
  122. default_qat_qconfig = QConfig(activation=default_fake_quant,
  123. weight=default_weight_fake_quant)
  124. """
  125. Default qconfig for QAT.
  126. """
  127. default_dynamic_qat_qconfig = QConfig(activation=default_dynamic_fake_quant,
  128. weight=default_weight_fake_quant)
  129. """
  130. Default qconfig for dynamic QAT.
  131. """
  132. default_weight_only_qconfig = QConfig(activation=torch.nn.Identity,
  133. weight=default_weight_fake_quant)
  134. """
  135. Default qconfig for quantizing weights only.
  136. """
  137. default_activation_only_qconfig = QConfig(activation=default_fake_quant,
  138. weight=torch.nn.Identity)
  139. """
  140. Default qconfig for quantizing activations only.
  141. """
  142. # QAT config that uses a fused observer + fake quant modules for optimized training performance.
  143. # to modify the activation/weight observers, the default entries in fake_quantize.py can be modified.
  144. default_qat_qconfig_v2 = QConfig(activation=default_fused_act_fake_quant, weight=default_fused_wt_fake_quant)
  145. """
  146. Fused version of `default_qat_config`, has performance benefits.
  147. """
  148. default_reuse_input_qconfig = QConfig(activation=default_reuse_input_observer,
  149. weight=NoopObserver)
  150. """
  151. Default qconfig for operators that reuse the observers from input Tensor, e.g. reshape
  152. """
  153. def get_default_qconfig(backend='fbgemm', version=0):
  154. """
  155. Returns the default PTQ qconfig for the specified backend.
  156. Args:
  157. * `backend`: a string representing the target backend. Currently supports `fbgemm`,
  158. `qnnpack` and `onednn`.
  159. Return:
  160. qconfig
  161. """
  162. if version == 0:
  163. if backend == 'fbgemm':
  164. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
  165. weight=default_per_channel_weight_observer)
  166. elif backend == 'qnnpack':
  167. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
  168. weight=default_weight_observer)
  169. elif backend == 'onednn':
  170. qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
  171. weight=default_per_channel_weight_observer)
  172. else:
  173. qconfig = default_qconfig
  174. else:
  175. raise AssertionError("Version number: " + str(version) +
  176. " in get_default_qconfig is not supported. Version number must be 0")
  177. return qconfig
  178. """
  179. Default, symmetric PTQ qconfig for the specified backend. And a per_channel
  180. variant of the same.
  181. Symmetric here applies to signed weights with zero point = 0, and additional
  182. value restrictions. The activations are also signed 8-bit integers with this
  183. qconfig.
  184. * Once this change is merged [as of 3/17/22], with backend or qengine =
  185. 'qnnpack', some quantized operators with this symmetric qconfig may use
  186. operators from xnnpack library.
  187. ** Support to use xnnpack ops with `qnnpack` backed for asymmetric
  188. qconfig (returned by get_default_qconfig()) is not available yet.
  189. * This qconfig uses signed activations and weights. Weights have added
  190. restrictions such as zero point is forced to be 0, making the weights
  191. symmetric, hence the name. And the 8-bit quantized values are
  192. restricting to to [-127, +127], excluding -128.
  193. * xnnpack has a requantization scale value restriction, 0x1p-32 <=
  194. requantization_scale < 256.0 where, `requantization_scale = (input_scale
  195. * kernel_scale) / (output_scale)`. Using this eps (w/ assumed max value
  196. of 256) is to prevent requantization_scale to go below xnnpack lower
  197. threshold.
  198. """
  199. default_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
  200. reduce_range=False,
  201. eps=2 ** -12),
  202. weight=weight_observer_range_neg_127_to_127)
  203. default_per_channel_symmetric_qnnpack_qconfig = QConfig(activation=HistogramObserver.with_args(dtype=torch.qint8,
  204. reduce_range=False,
  205. eps=2 ** -12),
  206. weight=per_channel_weight_observer_range_neg_127_to_127)
  207. default_embedding_qat_qconfig = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
  208. weight=default_embedding_fake_quant)
  209. default_embedding_qat_qconfig_4bit = QConfig(activation=NoopObserver.with_args(dtype=torch.float32),
  210. weight=default_embedding_fake_quant_4bit)
  211. def get_default_qat_qconfig(backend='fbgemm', version=1):
  212. """
  213. Returns the default QAT qconfig for the specified backend.
  214. Args:
  215. * `backend`: a string representing the target backend. Currently supports `fbgemm`,
  216. `qnnpack` and `onednn`.
  217. * `version`: version, for backwards compatibility. Can be `None` or `1`.
  218. Return:
  219. qconfig
  220. """
  221. # Histogram observer is too slow for quantization aware training
  222. if version == 0:
  223. if backend == 'fbgemm':
  224. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  225. quant_min=0,
  226. quant_max=255,
  227. reduce_range=True),
  228. weight=default_per_channel_weight_fake_quant)
  229. elif backend == 'qnnpack':
  230. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  231. quant_min=0,
  232. quant_max=255,
  233. reduce_range=False),
  234. weight=default_weight_fake_quant)
  235. elif backend == 'onednn':
  236. qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  237. quant_min=0,
  238. quant_max=255),
  239. weight=default_per_channel_weight_fake_quant)
  240. else:
  241. qconfig = default_qat_qconfig
  242. # Use the fused observe + fake_quant modules for doing QAT.
  243. elif version == 1:
  244. if backend == 'fbgemm':
  245. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  246. quant_min=0,
  247. quant_max=255,
  248. reduce_range=True),
  249. weight=default_fused_per_channel_wt_fake_quant)
  250. elif backend == 'qnnpack':
  251. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  252. quant_min=0,
  253. quant_max=255,
  254. reduce_range=False),
  255. weight=default_fused_wt_fake_quant)
  256. elif backend == 'onednn':
  257. qconfig = QConfig(activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  258. quant_min=0,
  259. quant_max=255),
  260. weight=default_fused_per_channel_wt_fake_quant)
  261. else:
  262. qconfig = default_qat_qconfig_v2
  263. else:
  264. raise AssertionError("Version number: " + str(version) +
  265. "in get_default_qat_qconfig is not supported. Version number must be 0 or 1")
  266. return qconfig
  267. """
  268. Default symmetric QAT qconfig for qnnpack. And its per channel weight variant.
  269. """
  270. default_symmetric_qnnpack_qat_qconfig = QConfig(
  271. activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  272. quant_min=-128,
  273. quant_max=127,
  274. dtype=torch.qint8,
  275. reduce_range=False,
  276. eps=2 ** -12),
  277. weight=fused_wt_fake_quant_range_neg_127_to_127)
  278. default_per_channel_symmetric_qnnpack_qat_qconfig = QConfig(
  279. activation=FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
  280. quant_min=-128,
  281. quant_max=127,
  282. dtype=torch.qint8,
  283. reduce_range=False,
  284. eps=2 ** -12),
  285. weight=fused_per_channel_wt_fake_quant_range_neg_127_to_127)
  286. def _get_default_qconfig_dict_helper(qconfig, qconfig_transpose):
  287. return {
  288. "": qconfig,
  289. "object_type": [("reshape", default_reuse_input_qconfig),
  290. (torch.nn.Conv1d, qconfig),
  291. (torch.nn.Conv2d, qconfig),
  292. (torch.nn.Conv3d, qconfig),
  293. (torch.nn.ConvTranspose1d, qconfig_transpose),
  294. (torch.nn.ConvTranspose2d, qconfig_transpose),
  295. (torch.nn.ConvTranspose3d, qconfig_transpose),
  296. (torch.nn.Linear, qconfig),
  297. (torch.nn.functional.conv1d, qconfig),
  298. (torch.nn.functional.conv2d, qconfig),
  299. (torch.nn.functional.conv3d, qconfig),
  300. (torch.nn.functional.conv_transpose1d, qconfig_transpose),
  301. (torch.nn.functional.conv_transpose2d, qconfig_transpose),
  302. (torch.nn.functional.conv_transpose3d, qconfig_transpose),
  303. (torch.nn.functional.linear, qconfig),
  304. (torch.nn.ReLU, qconfig),
  305. (torch.nn.functional.relu, qconfig),
  306. (torch.relu, qconfig),
  307. (torch.nn.BatchNorm1d, qconfig),
  308. (torch.nn.BatchNorm2d, qconfig),
  309. (torch.nn.BatchNorm3d, qconfig)]}
  310. def get_default_qconfig_dict(backend='fbgemm', version=0):
  311. qconfig = get_default_qconfig(backend, version)
  312. qconfig_transpose = qconfig
  313. # default_per_channel_weight_observer is not currently compatible with fbgemm backend
  314. # so we have to modify the weight observer to default_weight_observer or another
  315. # per tensor supported observer.
  316. # see https://github.com/pytorch/pytorch/issues/47535
  317. if backend == "fbgemm":
  318. qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_observer)
  319. return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
  320. def get_default_qat_qconfig_dict(backend='fbgemm', version=1):
  321. qconfig = get_default_qat_qconfig(backend, version)
  322. qconfig_transpose = qconfig
  323. # default_per_channel_weight_observer is not currently compatible with fbgemm backend
  324. # so we have to modify the weight observer to default_weight_observer or another
  325. # per tensor supported observer
  326. # see https://github.com/pytorch/pytorch/issues/47535
  327. if backend == "fbgemm":
  328. qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight_fake_quant)
  329. return _get_default_qconfig_dict_helper(qconfig, qconfig_transpose)
  330. def assert_valid_qconfig(qconfig: Optional[QConfig],
  331. mod: torch.nn.Module) -> None:
  332. """
  333. Verifies that this `qconfig` is valid.
  334. """
  335. if qconfig is None:
  336. return
  337. is_conv_transpose_mod = (
  338. isinstance(mod, torch.nn.ConvTranspose1d) or
  339. isinstance(mod, torch.nn.ConvTranspose2d) or
  340. isinstance(mod, torch.nn.ConvTranspose3d))
  341. if is_conv_transpose_mod:
  342. if qconfig.weight is None:
  343. # for now, we assume that any qconfig for ConvTranspose without a weight is valid
  344. return
  345. example_observer = qconfig.weight()
  346. is_per_channel = (
  347. isinstance(example_observer, torch.ao.quantization.PerChannelMinMaxObserver) or
  348. isinstance(example_observer, torch.ao.quantization.MovingAveragePerChannelMinMaxObserver)
  349. )
  350. assert not is_per_channel, \
  351. 'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
  352. # TODO: remove QConfigAny and replace it with Optional[QConfig]
  353. QConfigAny = Optional[QConfig]
  354. def add_module_to_qconfig_obs_ctr(
  355. qconfig: QConfigAny,
  356. module: Optional[nn.Module]) -> Any:
  357. r"""This is a helper function for use in quantization prepare that updates a qconfig so that
  358. the constructors stored in the qconfig will create observers on the same device that
  359. 'module' is on. This is intended to be used when the qconfigs are propagated to each
  360. module in order to avoid potential device alignment issues.
  361. Args:
  362. qconfig: QConfig with obs constructors stored in activation and weight
  363. module: module which the qconfig is related to
  364. Return:
  365. qconfig: configured so that obs constructors set to construct on the same device as module
  366. """
  367. if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
  368. return qconfig
  369. def get_factory_kwargs_based_on_module_device():
  370. assert isinstance(module, torch.nn.Module)
  371. devices = {p.device for p in module.parameters()} | \
  372. {p.device for p in module.buffers()}
  373. device = next(iter(devices)) if len(devices) > 0 else None
  374. return None if device is None else {'device': device}
  375. def configure_constructor_to_put_obs_on_module_device(original_constructor):
  376. try:
  377. # check if constructor can accept factory_kwargs
  378. check = original_constructor.with_args(factory_kwargs=None)
  379. check()
  380. return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
  381. except AttributeError: # qconfig doesn't have activation or weight
  382. return original_constructor
  383. except TypeError: # the class doesn't accept factory_kwargs argument
  384. return original_constructor
  385. activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
  386. weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)
  387. return QConfig(activation, weight)
  388. def qconfig_equals(q1: QConfigAny, q2: QConfigAny):
  389. """
  390. Returns `True` if `q1` equals `q2`, and `False` otherwise.
  391. """
  392. # functools.partial has no __eq__ operator defined so '==' defaults to 'is'
  393. def partial_equals(p1, p2):
  394. same = p1.func == p2.func
  395. same = same and p1.args == p2.args
  396. return same and p1.keywords == p2.keywords
  397. if q1 is None or q2 is None:
  398. return q1 == q2
  399. else:
  400. assert q1 is not None and q2 is not None
  401. try:
  402. # Qconfig weight and activation can be either a partial wrapper,
  403. # or an observer class. Special handling is required (above) for
  404. # comparing partial wrappers.
  405. if(isinstance(q1.activation, torch.ao.quantization.observer._PartialWrapper)):
  406. activation_same = partial_equals(q1.activation.p, q2.activation.p)
  407. else:
  408. activation_same = q1.activation == q2.activation
  409. if(isinstance(q1.weight, torch.ao.quantization.observer._PartialWrapper)):
  410. weight_same = partial_equals(q1.weight.p, q2.weight.p)
  411. else:
  412. weight_same = q1.weight == q2.weight
  413. return activation_same and weight_same
  414. except AttributeError:
  415. return q1 == q2
  416. def activation_is_memoryless(qconfig: QConfig):
  417. """
  418. Return whether the observer for activations defined in the given QConfig is memoryless.
  419. This means a MovingAverage observer with averaging constant equal to 1.
  420. """
  421. def _is_memoryless(observer):
  422. return hasattr(observer, "averaging_constant") and observer.averaging_constant == 1
  423. act = qconfig.activation()
  424. if isinstance(act, FakeQuantizeBase) and hasattr(act, "activation_post_process"):
  425. return _is_memoryless(act.activation_post_process)
  426. else:
  427. return _is_memoryless(act)
  428. def is_reuse_input_qconfig(qconfig: Optional[QConfig]):
  429. return qconfig is not None and \
  430. isinstance(qconfig.activation(), ReuseInputObserver) and \
  431. isinstance(qconfig.weight(), NoopObserver)