observer.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514
  1. """
  2. This module implements observers which are used to collect statistics about
  3. the values observed during calibration (PTQ) or training (QAT).
  4. """
  5. import re
  6. import warnings
  7. from abc import ABCMeta, abstractmethod
  8. from collections import OrderedDict
  9. from functools import partial
  10. from typing import Any, List, Tuple, Optional, Dict
  11. import torch
  12. import torch.nn as nn
  13. from torch.ao.quantization.utils import check_min_max_valid, calculate_qmin_qmax
  14. class _PartialWrapper(object):
  15. def __init__(self, p):
  16. self.p = p
  17. self.callable_args = {}
  18. def __call__(self, *args, **keywords):
  19. # call each arg in callable_args and add them partial, then run with keywords
  20. # skip if arg_name in keywords so its possible to overwrite
  21. for arg_name in self.callable_args:
  22. if arg_name not in keywords:
  23. keywords = {**keywords, **{arg_name: self.callable_args[arg_name]()}}
  24. return self.p(*args, **keywords)
  25. def __repr__(self):
  26. return self.p.__repr__() + self.callable_args.__repr__()
  27. def with_args(self, **kwargs):
  28. return _with_args(self, **kwargs)
  29. def with_callable_args(self, **kwargs):
  30. result = _PartialWrapper(p=self.p)
  31. result.callable_args = {**self.callable_args, **kwargs}
  32. return result
  33. def _with_args(cls_or_self, **kwargs):
  34. r"""Wrapper that allows creation of class factories.
  35. This can be useful when there is a need to create classes with the same
  36. constructor arguments, but different instances. Can be used in conjunction with
  37. _callable_args
  38. Example::
  39. >>> Foo.with_args = classmethod(_with_args)
  40. >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
  41. >>> foo_instance1 = foo_builder()
  42. >>> foo_instance2 = foo_builder()
  43. >>> id(foo_instance1) == id(foo_instance2)
  44. False
  45. """
  46. r = _PartialWrapper(partial(cls_or_self, **kwargs))
  47. return r
  48. def _with_callable_args(cls_or_self, **kwargs):
  49. r"""Wrapper that allows creation of class factories args that need to be
  50. called at construction time.
  51. This can be useful when there is a need to create classes with the same
  52. constructor arguments, but different instances and those arguments should only
  53. be calculated at construction time. Can be used in conjunction with _with_args
  54. Example::
  55. >>> Foo.with_callable_args = classmethod(_with_callable_args)
  56. >>> Foo.with_args = classmethod(_with_args)
  57. >>> foo_builder = Foo.with_callable_args(cur_time=get_time_func).with_args(name="dan")
  58. >>> foo_instance1 = foo_builder()
  59. >>> wait 50
  60. >>> foo_instance2 = foo_builder()
  61. >>> id(foo_instance1.creation_time) == id(foo_instance2.creation_time)
  62. False
  63. """
  64. r = _PartialWrapper(partial(cls_or_self))
  65. return r.with_callable_args(**kwargs)
  66. ABC: Any = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
  67. class ObserverBase(ABC, nn.Module):
  68. r"""Base observer Module.
  69. Any observer implementation should derive from this class.
  70. Concrete observers should follow the same API. In forward, they will update
  71. the statistics of the observed Tensor. And they should provide a
  72. `calculate_qparams` function that computes the quantization parameters given
  73. the collected statistics.
  74. Args:
  75. dtype: Quantized data type
  76. """
  77. def __init__(self, dtype):
  78. super(ObserverBase, self).__init__()
  79. self.dtype = dtype
  80. @abstractmethod
  81. def forward(self, x):
  82. pass
  83. @abstractmethod
  84. def calculate_qparams(self, **kwargs):
  85. pass
  86. with_args = classmethod(_with_args)
  87. with_callable_args = classmethod(_with_callable_args)
  88. class UniformQuantizationObserverBase(ObserverBase):
  89. r"""Common base for all observers using uniform quantization to calculate
  90. scale and zero_point.
  91. Args:
  92. dtype: Quantized data type.
  93. qscheme: Quantization scheme to be used.
  94. reduce_range: Reduces the range of the quantized data type by 1 bit.
  95. This is sometimes required to avoid instruction overflow.
  96. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  97. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  98. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  99. .. warning::
  100. :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
  101. .. warning::
  102. :attr:`qscheme` can only take one of the following options:
  103. - ``torch.per_tensor_affine``
  104. - ``torch.per_tensor_symmetric``
  105. - ``torch.per_channel_affine``
  106. - ``torch.per_channel_symmetric``
  107. """
  108. # Note: the version is shared by all observer types
  109. #
  110. # Version 1/None
  111. # self
  112. #
  113. # Version 2 (base class only, does not include child class buffers)
  114. # self
  115. # |--- eps : Tensor
  116. #
  117. # Version 3
  118. # for HistogramObserver only, changed the shape of uninitialized
  119. # min_val and max_val buffers from torch.Size([0]) to torch.Size([])
  120. # for PerChannelObservers, changed the name of the buffers from min_vals
  121. # to min_val and from max_vals to max_val.
  122. _version = 3
  123. eps: torch.Tensor
  124. def __init__(
  125. self,
  126. dtype=torch.quint8,
  127. qscheme=torch.per_tensor_affine,
  128. reduce_range=False,
  129. quant_min=None,
  130. quant_max=None,
  131. factory_kwargs=None,
  132. eps=torch.finfo(torch.float32).eps,
  133. ) -> None:
  134. factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
  135. super().__init__(dtype=dtype)
  136. self.qscheme = qscheme
  137. if reduce_range:
  138. warnings.warn(
  139. "Please use quant_min and quant_max to specify the range for observers. \
  140. reduce_range will be deprecated in a future release of PyTorch."
  141. )
  142. self.reduce_range = reduce_range
  143. self.register_buffer(
  144. "eps", torch.tensor([eps], **factory_kwargs)
  145. )
  146. assert self.qscheme in (
  147. torch.per_tensor_affine,
  148. torch.per_tensor_symmetric,
  149. torch.per_channel_affine,
  150. torch.per_channel_symmetric,
  151. torch.per_channel_affine_float_qparams,
  152. ), "Default Observer only works for per_tensor_affine, \
  153. per_tensor_symmetric, per_channel_affine, \
  154. per_channel_symmetric and per_channel_float_qparams quantization scheme"
  155. assert self.dtype in (
  156. torch.qint8,
  157. torch.quint8,
  158. torch.quint4x2,
  159. torch.qint32,
  160. ), "Default Observer only works for qint8, quint8 and quint4x2 data type"
  161. self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
  162. if self.has_customized_qrange:
  163. self._validate_qmin_qmax(quant_min, quant_max)
  164. self.quant_min, self.quant_max = \
  165. calculate_qmin_qmax(quant_min, quant_max, self.has_customized_qrange, self.dtype, self.reduce_range)
  166. def _load_from_state_dict(
  167. self,
  168. state_dict,
  169. prefix,
  170. local_metadata,
  171. strict,
  172. missing_keys,
  173. unexpected_keys,
  174. error_msgs,
  175. ):
  176. version = local_metadata.get("version", None)
  177. if version is None or version == 1:
  178. # eps was moved to a buffer in version 2
  179. eps = torch.tensor([torch.finfo(torch.float32).eps])
  180. state_dict[prefix + "eps"] = eps
  181. super(ObserverBase, self)._load_from_state_dict(
  182. state_dict,
  183. prefix,
  184. local_metadata,
  185. strict,
  186. missing_keys,
  187. unexpected_keys,
  188. error_msgs,
  189. )
  190. @torch.jit.export
  191. def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None:
  192. r"""Validates that the user-specified quantization range is properly initialized
  193. and within the given bound supported by the observer dtype.
  194. To accommodate lower-bit quantization with respect to the existing torch.qint8 and
  195. torch.quint8 datatypes, the user can choose to use dynamic quantization range by passing
  196. in a tuple of initial qmin and qmax values. One use case is these customized qmin and qmax
  197. values are used to calculate static estimates of the scale and zero point for aggressive lower-bit
  198. fake quantization. These estimates are compared against parameters learned through backpropagation.
  199. The related literatures for scale and zero point via backpropagation are as follows:
  200. Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
  201. Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
  202. """
  203. # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted
  204. # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer.
  205. assert (
  206. quant_min <= 0 <= quant_max
  207. ), "Used-specified quantization range must include 0."
  208. assert (
  209. quant_min < quant_max
  210. ), "qmin must be strictly less than qmax for user-specified quantization range."
  211. @torch.jit.export
  212. def _calculate_qparams(
  213. self, min_val: torch.Tensor, max_val: torch.Tensor
  214. ) -> Tuple[torch.Tensor, torch.Tensor]:
  215. r"""Calculates the quantization parameters, given min and max
  216. value tensors. Works for both per tensor and per channel cases
  217. Args:
  218. min_val: Minimum values per channel
  219. max_val: Maximum values per channel
  220. Returns:
  221. scales: Scales tensor of shape (#channels,)
  222. zero_points: Zero points tensor of shape (#channels,)
  223. """
  224. if not check_min_max_valid(min_val, max_val):
  225. return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
  226. quant_min, quant_max = self.quant_min, self.quant_max
  227. min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
  228. max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
  229. device = min_val_neg.device
  230. scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
  231. zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  232. if (
  233. self.qscheme == torch.per_tensor_symmetric
  234. or self.qscheme == torch.per_channel_symmetric
  235. ):
  236. max_val_pos = torch.max(-min_val_neg, max_val_pos)
  237. scale = max_val_pos / (float(quant_max - quant_min) / 2)
  238. scale = torch.max(scale, self.eps)
  239. if self.dtype == torch.quint8:
  240. if self.has_customized_qrange:
  241. # When customized quantization range is used, down-rounded midpoint of the range is chosen.
  242. zero_point = zero_point.new_full(
  243. zero_point.size(), (quant_min + quant_max) // 2
  244. )
  245. else:
  246. zero_point = zero_point.new_full(zero_point.size(), 128)
  247. elif self.qscheme == torch.per_channel_affine_float_qparams:
  248. scale = (max_val - min_val) / float(quant_max - quant_min)
  249. scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
  250. # We use the quantize function
  251. # xq = Round(Xf * inv_scale + zero_point),
  252. # setting zero_point to (-1 * min *inv_scale) we get
  253. # Xq = Round((Xf - min) * inv_scale)
  254. zero_point = -1 * min_val / scale
  255. else:
  256. scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
  257. scale = torch.max(scale, self.eps)
  258. zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
  259. zero_point = torch.clamp(zero_point, quant_min, quant_max)
  260. # For scalar values, cast them to Tensors of size 1 to keep the shape
  261. # consistent with default values in FakeQuantize.
  262. if len(scale.shape) == 0:
  263. # TODO: switch to scale.item() after adding JIT support
  264. scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
  265. if len(zero_point.shape) == 0:
  266. # TODO: switch to zero_point.item() after adding JIT support
  267. zero_point = torch.tensor(
  268. [int(zero_point)], dtype=zero_point.dtype, device=device
  269. )
  270. if self.qscheme == torch.per_channel_affine_float_qparams:
  271. zero_point = torch.tensor(
  272. [float(zero_point)], dtype=zero_point.dtype, device=device
  273. )
  274. return scale, zero_point
  275. @torch.jit.export
  276. def reset_min_max_vals(self):
  277. raise NotImplementedError("Cannot reset min/max values in the given observer.")
  278. # Originally, this class was called `_ObserverBase`. Keeping the old name around
  279. # for backwards compatibility.
  280. # TODO(after v1.13): delete this
  281. _ObserverBase = UniformQuantizationObserverBase
  282. class MinMaxObserver(UniformQuantizationObserverBase):
  283. r"""Observer module for computing the quantization parameters based on the
  284. running min and max values.
  285. This observer uses the tensor min/max statistics to compute the quantization
  286. parameters. The module records the running minimum and maximum of incoming
  287. tensors, and uses this statistic to compute the quantization parameters.
  288. Args:
  289. dtype: Quantized data type
  290. qscheme: Quantization scheme to be used
  291. reduce_range: Reduces the range of the quantized data type by 1 bit
  292. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  293. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  294. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  295. Given running min/max as :math:`x_\text{min}` and :math:`x_\text{max}`,
  296. scale :math:`s` and zero point :math:`z` are computed as:
  297. The running minimum/maximum :math:`x_\text{min/max}` is computed as:
  298. .. math::
  299. \begin{array}{ll}
  300. x_\text{min} &= \begin{cases}
  301. \min(X) & \text{if~}x_\text{min} = \text{None} \\
  302. \min\left(x_\text{min}, \min(X)\right) & \text{otherwise}
  303. \end{cases}\\
  304. x_\text{max} &= \begin{cases}
  305. \max(X) & \text{if~}x_\text{max} = \text{None} \\
  306. \max\left(x_\text{max}, \max(X)\right) & \text{otherwise}
  307. \end{cases}\\
  308. \end{array}
  309. where :math:`X` is the observed tensor.
  310. The scale :math:`s` and zero point :math:`z` are then computed as:
  311. .. math::
  312. \begin{aligned}
  313. \text{if Symmetric:}&\\
  314. &s = 2 \max(|x_\text{min}|, x_\text{max}) /
  315. \left( Q_\text{max} - Q_\text{min} \right) \\
  316. &z = \begin{cases}
  317. 0 & \text{if dtype is qint8} \\
  318. 128 & \text{otherwise}
  319. \end{cases}\\
  320. \text{Otherwise:}&\\
  321. &s = \left( x_\text{max} - x_\text{min} \right ) /
  322. \left( Q_\text{max} - Q_\text{min} \right ) \\
  323. &z = Q_\text{min} - \text{round}(x_\text{min} / s)
  324. \end{aligned}
  325. where :math:`Q_\text{min}` and :math:`Q_\text{max}` are the minimum and
  326. maximum of the quantized data type.
  327. .. warning:: :attr:`dtype` can only take ``torch.qint8`` or ``torch.quint8``.
  328. .. note:: If the running minimum equals to the running maximum, the scale
  329. and zero_point are set to 1.0 and 0.
  330. """
  331. min_val: torch.Tensor
  332. max_val: torch.Tensor
  333. def __init__(
  334. self,
  335. dtype=torch.quint8,
  336. qscheme=torch.per_tensor_affine,
  337. reduce_range=False,
  338. quant_min=None,
  339. quant_max=None,
  340. factory_kwargs=None,
  341. eps=torch.finfo(torch.float32).eps,
  342. ) -> None:
  343. # For x86 quantized kernels, we need to ensure that the vpmaddubsw
  344. # instruction does not overflow. We allow for a reduce_range argument to
  345. # observers that reduces the quantized range to (0,127) or (-64, 63).
  346. # For more details see aten/src/ATen/native/quantized/cpu/qconv.cpp
  347. # This is not an optimal choice for non x86 backends as it loses a bit
  348. # of precision for activations.
  349. super(MinMaxObserver, self).__init__(
  350. dtype=dtype,
  351. qscheme=qscheme,
  352. reduce_range=reduce_range,
  353. quant_min=quant_min,
  354. quant_max=quant_max,
  355. factory_kwargs=factory_kwargs,
  356. eps=eps,
  357. )
  358. factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
  359. self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
  360. self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
  361. if (
  362. self.qscheme == torch.per_tensor_symmetric
  363. and self.reduce_range
  364. and self.dtype == torch.quint8
  365. ):
  366. raise NotImplementedError(
  367. "Cannot reduce range for symmetric \
  368. quantization for quint8"
  369. )
  370. def forward(self, x_orig):
  371. r"""Records the running minimum and maximum of ``x``."""
  372. if x_orig.numel() == 0:
  373. return x_orig
  374. x = x_orig.detach() # avoid keeping autograd tape
  375. x = x.to(self.min_val.dtype)
  376. min_val_cur, max_val_cur = torch.aminmax(x)
  377. min_val = torch.min(min_val_cur, self.min_val)
  378. max_val = torch.max(max_val_cur, self.max_val)
  379. self.min_val.copy_(min_val)
  380. self.max_val.copy_(max_val)
  381. return x_orig
  382. @torch.jit.export
  383. def calculate_qparams(self):
  384. r"""Calculates the quantization parameters."""
  385. return self._calculate_qparams(self.min_val, self.max_val)
  386. @torch.jit.export
  387. def extra_repr(self):
  388. return "min_val={}, max_val={}".format(self.min_val, self.max_val)
  389. @torch.jit.export
  390. def reset_min_max_vals(self):
  391. """Resets the min/max values."""
  392. self.min_val.copy_(torch.tensor(float("inf")))
  393. self.max_val.copy_(torch.tensor(float("-inf")))
  394. class MovingAverageMinMaxObserver(MinMaxObserver):
  395. r"""Observer module for computing the quantization parameters based on the
  396. moving average of the min and max values.
  397. This observer computes the quantization parameters based on the moving
  398. averages of minimums and maximums of the incoming tensors. The module
  399. records the average minimum and maximum of incoming tensors, and uses this
  400. statistic to compute the quantization parameters.
  401. Args:
  402. averaging_constant: Averaging constant for min/max.
  403. dtype: Quantized data type
  404. qscheme: Quantization scheme to be used
  405. reduce_range: Reduces the range of the quantized data type by 1 bit
  406. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  407. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  408. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  409. The moving average min/max is computed as follows
  410. .. math::
  411. \begin{array}{ll}
  412. x_\text{min} = \begin{cases}
  413. \min(X) & \text{if~}x_\text{min} = \text{None} \\
  414. (1 - c) x_\text{min} + c \min(X) & \text{otherwise}
  415. \end{cases}\\
  416. x_\text{max} = \begin{cases}
  417. \max(X) & \text{if~}x_\text{max} = \text{None} \\
  418. (1 - c) x_\text{max} + c \max(X) & \text{otherwise}
  419. \end{cases}\\
  420. \end{array}
  421. where :math:`x_\text{min/max}` is the running average min/max, :math:`X` is
  422. is the incoming tensor, and :math:`c` is the ``averaging_constant``.
  423. The scale and zero point are then computed as in
  424. :class:`~torch.ao.quantization.observer.MinMaxObserver`.
  425. .. note:: Only works with ``torch.per_tensor_affine`` quantization scheme.
  426. .. note:: If the running minimum equals to the running maximum, the scale
  427. and zero_point are set to 1.0 and 0.
  428. """
  429. def __init__(
  430. self,
  431. averaging_constant=0.01,
  432. dtype=torch.quint8,
  433. qscheme=torch.per_tensor_affine,
  434. reduce_range=False,
  435. quant_min=None,
  436. quant_max=None,
  437. eps=torch.finfo(torch.float32).eps,
  438. **kwargs
  439. ) -> None:
  440. self.averaging_constant = averaging_constant
  441. super(MovingAverageMinMaxObserver, self).__init__(
  442. dtype=dtype,
  443. qscheme=qscheme,
  444. reduce_range=reduce_range,
  445. quant_min=quant_min,
  446. quant_max=quant_max,
  447. eps=eps,
  448. **kwargs
  449. )
  450. def forward(self, x_orig):
  451. if x_orig.numel() == 0:
  452. return x_orig
  453. x = x_orig.detach() # avoid keeping autograd tape
  454. x = x.to(self.min_val.dtype)
  455. min_val = self.min_val
  456. max_val = self.max_val
  457. if min_val == float("inf") and max_val == float("-inf"):
  458. min_val, max_val = torch.aminmax(x)
  459. else:
  460. min_val_cur, max_val_cur = torch.aminmax(x)
  461. min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
  462. max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
  463. self.min_val.copy_(min_val)
  464. self.max_val.copy_(max_val)
  465. return x_orig
  466. class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
  467. r"""Observer module for computing the quantization parameters based on the
  468. running per channel min and max values.
  469. This observer uses the tensor min/max statistics to compute the per channel
  470. quantization parameters. The module records the running minimum and maximum
  471. of incoming tensors, and uses this statistic to compute the quantization
  472. parameters.
  473. Args:
  474. ch_axis: Channel axis
  475. dtype: Quantized data type
  476. qscheme: Quantization scheme to be used
  477. reduce_range: Reduces the range of the quantized data type by 1 bit
  478. quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
  479. quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
  480. eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.
  481. The quantization parameters are computed the same way as in
  482. :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
  483. that the running min/max values are stored per channel.
  484. Scales and zero points are thus computed per channel as well.
  485. .. note:: If the running minimum equals to the running maximum, the scales
  486. and zero_points are set to 1.0 and 0.
  487. """
  488. min_val: torch.Tensor
  489. max_val: torch.Tensor
  490. def __init__(
  491. self,
  492. ch_axis=0,
  493. dtype=torch.quint8,
  494. qscheme=torch.per_channel_affine,
  495. reduce_range=False,
  496. quant_min=None,
  497. quant_max=None,
  498. factory_kwargs=None,
  499. eps=torch.finfo(torch.float32).eps,
  500. ) -> None:
  501. super(PerChannelMinMaxObserver, self).__init__(
  502. dtype=dtype,
  503. qscheme=qscheme,
  504. reduce_range=reduce_range,
  505. quant_min=quant_min,
  506. quant_max=quant_max,
  507. factory_kwargs=factory_kwargs,
  508. eps=eps,
  509. )
  510. factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
  511. self.ch_axis = ch_axis
  512. self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
  513. self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
  514. if (
  515. self.qscheme == torch.per_channel_symmetric
  516. and self.reduce_range
  517. and self.dtype == torch.quint8
  518. ):
  519. raise NotImplementedError(
  520. "Cannot reduce range for symmetric quantization for quint8"
  521. )
  522. def forward(self, x_orig):
  523. return self._forward(x_orig)
  524. def _forward(self, x_orig):
  525. if x_orig.numel() == 0:
  526. return x_orig
  527. x = x_orig.detach() # avoid keeping autograd tape
  528. min_val = self.min_val
  529. max_val = self.max_val
  530. x_dim = x.size()
  531. new_axis_list = [i for i in range(len(x_dim))] # noqa: C416
  532. new_axis_list[self.ch_axis] = 0
  533. new_axis_list[0] = self.ch_axis
  534. y = x.permute(new_axis_list)
  535. # Need to match dtype of min/max because the updates to buffers
  536. # are done in place and types need to match for comparisons
  537. y = y.to(self.min_val.dtype)
  538. y = torch.flatten(y, start_dim=1)
  539. if min_val.numel() == 0 or max_val.numel() == 0:
  540. min_val, max_val = torch.aminmax(y, dim=1)
  541. else:
  542. min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
  543. min_val = torch.min(min_val_cur, min_val)
  544. max_val = torch.max(max_val_cur, max_val)
  545. self.min_val.resize_(min_val.shape)
  546. self.max_val.resize_(max_val.shape)
  547. self.min_val.copy_(min_val)
  548. self.max_val.copy_(max_val)
  549. return x_orig
  550. @torch.jit.export
  551. def calculate_qparams(self):
  552. return self._calculate_qparams(self.min_val, self.max_val)
  553. def extra_repr(self):
  554. return "min_val={}, max_val={}".format(self.min_val, self.max_val)
  555. def _load_from_state_dict(
  556. self,
  557. state_dict: Dict[str, Any],
  558. prefix: str,
  559. local_metadata: Dict[str, torch.Tensor],
  560. strict: bool,
  561. missing_keys: List[str],
  562. unexpected_keys: List[str],
  563. error_msgs: List[str],
  564. ):
  565. version = local_metadata.get("version", None)
  566. if version is None or version < 3:
  567. local_state = ["min_vals", "max_vals"]
  568. expected_min_name = "min_vals"
  569. expected_max_name = "max_vals"
  570. else:
  571. local_state = ["min_val", "max_val"]
  572. expected_min_name = "min_val"
  573. expected_max_name = "max_val"
  574. for name in local_state:
  575. key = prefix + name
  576. if key in state_dict:
  577. val = state_dict[key]
  578. # Custom handling to allow loading min_val or max_val
  579. # of size N into uninitialized buffers of size 0. The
  580. # buffers are resized here, and the values are copied in
  581. # the default state_dict loading code of the parent.
  582. if name == expected_min_name:
  583. self.min_val.resize_(val.shape)
  584. elif name == expected_max_name:
  585. self.max_val.resize_(val.shape)
  586. else:
  587. warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name))
  588. # For torchscript module we need to update the attributes here since we do not
  589. # call the `_load_from_state_dict` function defined module.py
  590. if torch.jit.is_scripting():
  591. if name == expected_min_name:
  592. self.min_val.copy_(val)
  593. elif name == expected_max_name:
  594. self.max_val.copy_(val)
  595. else:
  596. warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name))
  597. elif strict:
  598. missing_keys.append(key)
  599. if not torch.jit.is_scripting():
  600. super(PerChannelMinMaxObserver, self)._load_from_state_dict(
  601. state_dict,
  602. prefix,
  603. local_metadata,
  604. False,
  605. missing_keys,
  606. unexpected_keys,
  607. error_msgs,
  608. )
  609. def _load_from_state_dict_script(
  610. self,
  611. state_dict: Dict[str, Any],
  612. prefix: str,
  613. local_metadata: Dict[str, torch.Tensor],
  614. strict: bool,
  615. missing_keys: List[str],
  616. unexpected_keys: List[str],
  617. error_msgs: List[str],
  618. ):
  619. self._load_from_state_dict(
  620. state_dict,
  621. prefix,
  622. local_metadata,
  623. strict,
  624. missing_keys,
  625. unexpected_keys,
  626. error_msgs,
  627. )
  @torch.jit.export
  def reset_min_max_vals(self):
      """Resets the min/max values to empty (uninitialized) tensors.

      The new buffers are created on the device of the current ``min_val``,
      so resetting an observer that was moved with ``.to(device)`` does not
      silently bring its statistics back to the CPU.
      """
      # Two independent torch.rand(0) calls instead of two identical constant
      # constructions such as torch.tensor([]): TorchScript's common
      # subexpression elimination could fold identical deterministic
      # constructions into one shared tensor, aliasing min_val and max_val so
      # that the in-place resize_/copy_ in forward() would clobber one with
      # the other. torch.rand is nondeterministic and is not folded.
      device = self.min_val.device
      self.min_val = torch.rand(0, device=device)
      self.max_val = torch.rand(0, device=device)
  class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
      r"""Observer module for computing the quantization parameters based on the
      running per channel min and max values.

      This observer uses the tensor min/max statistics to compute the per channel
      quantization parameters. The module records the running minimum and maximum
      of incoming tensors, and uses this statistic to compute the quantization
      parameters.

      Args:
          averaging_constant: Averaging constant for min/max.
          ch_axis: Channel axis
          dtype: Quantized data type
          qscheme: Quantization scheme to be used
          reduce_range: Reduces the range of the quantized data type by 1 bit
          quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
          quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
          eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.

      On the first non-empty batch the running statistics are set to that
      batch's per-channel min/max; on every later batch, for each channel ``c``::

          running_min[c] += averaging_constant * (batch_min[c] - running_min[c])
          running_max[c] += averaging_constant * (batch_max[c] - running_max[c])

      The quantization parameters are computed the same way as in
      :class:`~torch.ao.quantization.observer.MovingAverageMinMaxObserver`, with the
      difference that the running min/max values are stored per channel.
      Scales and zero points are thus computed per channel as well.

      .. note:: If the running minimum equals the running maximum, the scales
                and zero_points are set to 1.0 and 0.
      """

      def __init__(
          self,
          averaging_constant=0.01,
          ch_axis=0,
          dtype=torch.quint8,
          qscheme=torch.per_channel_affine,
          reduce_range=False,
          quant_min=None,
          quant_max=None,
          eps=torch.finfo(torch.float32).eps,
          **kwargs
      ) -> None:
          super().__init__(
              ch_axis=ch_axis,
              dtype=dtype,
              qscheme=qscheme,
              reduce_range=reduce_range,
              quant_min=quant_min,
              quant_max=quant_max,
              eps=eps,
              **kwargs
          )
          self.averaging_constant = averaging_constant

      def forward(self, x_orig):
          """Fold ``x_orig`` into the running per-channel min/max; return it unchanged."""
          if x_orig.numel() == 0:
              return x_orig
          # Detach so the observer never keeps the autograd tape alive, and
          # match the dtype of the stored statistics.
          x = x_orig.detach().to(self.min_val.dtype)

          # Swap ch_axis with dimension 0 and flatten everything else, so that
          # row c of ``per_channel`` holds every element of channel c.
          # (Comprehension kept on purpose: this method is compiled by TorchScript.)
          perm = [d for d in range(x.dim())]  # noqa: C416
          perm[self.ch_axis] = 0
          perm[0] = self.ch_axis
          per_channel = torch.flatten(x.permute(perm), start_dim=1)
          batch_min, batch_max = torch.aminmax(per_channel, dim=1)

          running_min = self.min_val
          running_max = self.max_val
          if running_min.numel() == 0 or running_max.numel() == 0:
              # Nothing observed yet: seed the statistics with this batch.
              new_min, new_max = batch_min, batch_max
          else:
              c = self.averaging_constant
              new_min = running_min + c * (batch_min - running_min)
              new_max = running_max + c * (batch_max - running_max)

          # Buffers may still be empty, so resize before copying in place.
          self.min_val.resize_(new_min.shape)
          self.max_val.resize_(new_max.shape)
          self.min_val.copy_(new_min)
          self.max_val.copy_(new_max)
          return x_orig
  class HistogramObserver(UniformQuantizationObserverBase):
      r"""
      The module records the running histogram of tensor values along with
      min/max values. ``calculate_qparams`` will calculate scale and zero_point.

      Args:
          bins: Number of bins to use for the histogram
          upsample_rate: Factor by which the histograms are upsampled, this is
                         used to interpolate histograms with varying ranges across observations
          dtype: Quantized data type
          qscheme: Quantization scheme to be used
          reduce_range: Reduces the range of the quantized data type by 1 bit
          quant_min: Minimum quantization value; forwarded to the base class.
          quant_max: Maximum quantization value; forwarded to the base class.
          factory_kwargs: Factory keyword arguments, normalized with
                          ``torch.nn.factory_kwargs`` and used to create the
                          ``histogram``, ``min_val`` and ``max_val`` buffers.
          eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.

      The scale and zero point are computed as follows:

      1. Create the histogram of the incoming inputs.
          The histogram is computed continuously, and the ranges per bin change
          with every new tensor observed.
      2. Search the distribution in the histogram for optimal min/max values.
          The search approximately minimizes the L2 quantization error with
          respect to the floating point model.
      3. Compute the scale and zero point the same way as in the
          :class:`~torch.ao.quantization.MinMaxObserver`
      """

      histogram: torch.Tensor
      min_val: torch.Tensor
      max_val: torch.Tensor

      def __init__(
          self,
          bins: int = 2048,
          upsample_rate: int = 128,
          dtype: torch.dtype = torch.quint8,
          qscheme=torch.per_tensor_affine,
          reduce_range=False,
          quant_min=None,
          quant_max=None,
          factory_kwargs=None,
          eps=torch.finfo(torch.float32).eps,
      ) -> None:
          # bins: The number of bins used for histogram calculation.
          super(HistogramObserver, self).__init__(
              dtype=dtype,
              qscheme=qscheme,
              reduce_range=reduce_range,
              quant_min=quant_min,
              quant_max=quant_max,
              factory_kwargs=factory_kwargs,
              eps=eps,
          )
          factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
          self.bins = bins
          # Start "uninitialized": an all-zero histogram and an empty
          # (+inf, -inf) range, which forward() detects on its first call.
          self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
          self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
          self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
          # Number of destination (quantized) bins: 2 ** (bit width of dtype).
          self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
          self.upsample_rate = upsample_rate

      def _get_norm(
          self, delta_begin: torch.Tensor, delta_end: torch.Tensor, density: torch.Tensor
      ) -> torch.Tensor:
          r"""
          Compute the norm of the values uniformly distributed between
          delta_begin and delta_end.
          Currently only L2 norm is supported.

          norm = density * (integral_{begin, end} x^2)
               = density * (end^3 - begin^3) / 3
          """
          norm = (
              delta_end * delta_end * delta_end - delta_begin * delta_begin * delta_begin
          ) / 3
          return density * norm

      def _compute_quantization_error(self, next_start_bin: int, next_end_bin: int):
          r"""
          Compute the quantization error if we use start_bin to end_bin as the
          min and max to do the quantization.

          The source range covered by bins ``next_start_bin..next_end_bin`` is
          split into ``self.dst_nbins`` equal destination bins. For every
          source bin (its values treated as uniformly spread, with density
          ``histogram / bin_width``), the L2 norm of the distance to the centre
          of the destination bin(s) it falls into is accumulated via
          ``_get_norm``. Source bins outside the candidate range are clamped to
          the first/last destination bin. Returns 0.0 for a zero-width range.
          """
          bin_width = (self.max_val.item() - self.min_val.item()) / self.bins
          dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
          if dst_bin_width == 0.0:
              return 0.0

          src_bin = torch.arange(self.bins, device=self.histogram.device)
          # distances from the beginning of first dst_bin to the beginning and
          # end of src_bin
          src_bin_begin = (src_bin - next_start_bin) * bin_width
          src_bin_end = src_bin_begin + bin_width

          # which dst_bins the beginning and end of src_bin belong to?
          dst_bin_of_begin = torch.clamp(
              torch.div(src_bin_begin, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1
          )
          dst_bin_of_begin_center = (dst_bin_of_begin + 0.5) * dst_bin_width

          dst_bin_of_end = torch.clamp(
              torch.div(src_bin_end, dst_bin_width, rounding_mode='floor'), 0, self.dst_nbins - 1
          )

          density = self.histogram / bin_width

          norm = torch.zeros(self.bins, device=self.histogram.device)

          # (1) from src_bin_begin up to the upper edge of its dst bin
          delta_begin = src_bin_begin - dst_bin_of_begin_center
          delta_end = dst_bin_width / 2
          norm += self._get_norm(delta_begin,
                                 torch.ones(self.bins, device=self.histogram.device) * delta_end,
                                 density)

          # (2) every whole dst bin strictly between the first and the last one
          norm += (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm(
              torch.tensor(-dst_bin_width / 2), torch.tensor(dst_bin_width / 2), density
          )

          # (3) from the lower edge of the last dst bin up to src_bin_end
          dst_bin_of_end_center = dst_bin_of_end * dst_bin_width + dst_bin_width / 2

          delta_begin = -dst_bin_width / 2
          delta_end = src_bin_end - dst_bin_of_end_center
          norm += self._get_norm(torch.tensor(delta_begin), delta_end, density)

          return norm.sum().item()

      def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]:
          r"""Non-linear parameter search.

          An approximation for L2 error minimization for selecting min/max.
          By selecting new min/max, we filter out outliers in input distribution.
          This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
          caffe2/quantization/server/norm_minimization.cc

          Starting from the full bin range, each step tightens either the lower
          quantile bound (alpha) or the upper one (beta) by ``stepsize``, moving
          whichever end would drop more bins. The search stops as soon as the
          quantization error of the candidate range increases.

          Returns:
              ``(new_min, new_max)`` as tensors, the float bounds of the
              selected bin range.
          """
          assert self.histogram.size()[0] == self.bins, "bins mistmatch"
          bin_width = (self.max_val - self.min_val) / self.bins

          # cumulative sum
          total = torch.sum(self.histogram).item()
          cSum = torch.cumsum(self.histogram, dim=0)

          stepsize = 1e-5  # granularity
          alpha = 0.0  # lower bound
          beta = 1.0  # upper bound
          start_bin = 0
          end_bin = self.bins - 1
          norm_min = float("inf")

          while alpha < beta:
              # Find the next step
              next_alpha = alpha + stepsize
              next_beta = beta - stepsize

              # find the left and right bins between the quantile bounds
              l = start_bin
              r = end_bin
              while l < end_bin and cSum[l] < next_alpha * total:
                  l = l + 1
              while r > start_bin and cSum[r] > next_beta * total:
                  r = r - 1

              # decide the next move
              next_start_bin = start_bin
              next_end_bin = end_bin
              if (l - start_bin) > (end_bin - r):
                  # move the start bin
                  next_start_bin = l
                  alpha = next_alpha
              else:
                  # move the end bin
                  next_end_bin = r
                  beta = next_beta

              if next_start_bin == start_bin and next_end_bin == end_bin:
                  continue

              # calculate the quantization error using next_start_bin and next_end_bin
              norm = self._compute_quantization_error(next_start_bin, next_end_bin)

              if norm > norm_min:
                  break
              norm_min = norm
              start_bin = next_start_bin
              end_bin = next_end_bin

          new_min = self.min_val + bin_width * start_bin
          new_max = self.min_val + bin_width * (end_bin + 1)
          return new_min, new_max

      def _adjust_min_max(
          self, combined_min: torch.Tensor, combined_max: torch.Tensor, upsample_rate: int
      ) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
          """Widen ``combined_max`` so both histograms share a common grid.

          Returns ``(combined_min, combined_max, downsample_rate, start_idx)``
          where ``start_idx`` maps the current ``min_val`` onto a bin index of
          the upsampled grid of resolution ``hist_bin_width``.
          """
          # We ensure that:
          # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
          # This allows us to have a common grid of resolution s, where we can align
          # the input histogram
          # start_idx maps min_val to the histogram bin index.
          hist_bin_width = (self.max_val - self.min_val) / (self.bins * upsample_rate)
          downsample_rate = int(
              torch.ceil(
                  (combined_max - combined_min) / (self.bins * hist_bin_width)
              ).item()
          )
          e = downsample_rate * (self.bins * hist_bin_width) - (
              combined_max - combined_min
          )
          # Relax only the max, not the min, so that for one sided distributions, min stays at zero
          combined_max = combined_max + e
          start_idx = int(
              torch.round((self.min_val - combined_min) / hist_bin_width).item()
          )
          return combined_min, combined_max, downsample_rate, start_idx

      def _combine_histograms(
          self,
          orig_hist: torch.Tensor,
          new_hist: torch.Tensor,
          upsample_rate: int,
          downsample_rate: int,
          start_idx: int,
          Nbins: int,
      ) -> torch.Tensor:
          """Resample ``new_hist`` onto the grid of ``orig_hist`` and add them."""
          # First up-sample the histogram with new data by a factor of L
          # This creates an approximate probability density that's piecewise constant
          upsampled_histogram = new_hist.repeat_interleave(upsample_rate)
          # Now insert the upsampled histogram into the output
          # histogram, which is initialized with zeros.
          # The offset at which the histogram is introduced is determined
          # by the start index as the output histogram can cover a wider range
          histogram_with_output_range = torch.zeros(
              (Nbins * downsample_rate), device=orig_hist.device
          )
          histogram_with_output_range[
              start_idx : Nbins * upsample_rate + start_idx
          ] = upsampled_histogram
          # Compute integral histogram, double precision is needed to ensure
          # that there are no overflows
          integral_histogram = torch.cumsum(
              histogram_with_output_range, 0, dtype=torch.double
          )[downsample_rate - 1 :: downsample_rate]
          # Finally perform interpolation
          shifted_integral_histogram = torch.zeros((Nbins), device=orig_hist.device)
          shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
          interpolated_histogram = (
              integral_histogram - shifted_integral_histogram
          ) / upsample_rate
          orig_hist = orig_hist + interpolated_histogram.to(torch.float)
          return orig_hist

      def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
          """Fold ``x_orig`` into the running histogram and return it unchanged."""
          if x_orig.numel() == 0:
              return x_orig
          x = x_orig.detach()
          min_val = self.min_val
          max_val = self.max_val
          same_values = min_val.item() == max_val.item()
          is_uninitialized = min_val == float("inf") and max_val == float("-inf")
          if is_uninitialized or same_values:
              # First observation, or a zero-width range (which _adjust_min_max
              # cannot rescale because it divides by the bin width): rebuild
              # the histogram from x alone.
              min_val, max_val = torch.aminmax(x)
              self.min_val.resize_(min_val.shape)
              self.min_val.copy_(min_val)
              self.max_val.resize_(max_val.shape)
              self.max_val.copy_(max_val)
              assert (
                  min_val.numel() == 1 and max_val.numel() == 1
              ), "histogram min/max values must be scalar."
              # Use the exact float bounds. Truncating them with int() binned
              # over a different range than the stored min/max and silently
              # dropped every value outside the truncated range.
              torch.histc(
                  x, self.bins, min=min_val.item(), max=max_val.item(), out=self.histogram
              )
          else:
              new_min, new_max = torch.aminmax(x)
              combined_min = torch.min(new_min, min_val)
              combined_max = torch.max(new_max, max_val)
              # combine the existing histogram and new histogram into 1 histogram
              # We do this by first upsampling the histogram to a dense grid
              # and then downsampling the histogram efficiently
              (
                  combined_min,
                  combined_max,
                  downsample_rate,
                  start_idx,
              ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
              assert (
                  combined_min.numel() == 1 and combined_max.numel() == 1
              ), "histogram min/max values must be scalar."
              # Exact float bounds here too (see the comment above).
              combined_histogram = torch.histc(
                  x, self.bins, min=combined_min.item(), max=combined_max.item()
              )
              if combined_min == min_val and combined_max == max_val:
                  combined_histogram += self.histogram
              else:
                  combined_histogram = self._combine_histograms(
                      combined_histogram,
                      self.histogram,
                      self.upsample_rate,
                      downsample_rate,
                      start_idx,
                      self.bins,
                  )

              self.histogram.detach_().resize_(combined_histogram.shape)
              self.histogram.copy_(combined_histogram)
              self.min_val.detach_().resize_(combined_min.shape)
              self.min_val.copy_(combined_min)
              self.max_val.detach_().resize_(combined_max.shape)
              self.max_val.copy_(combined_max)
          return x_orig

      @torch.jit.export
      def calculate_qparams(self):
          """Return ``(scale, zero_point)`` for the searched min/max range.

          If nothing has been observed yet, a warning is issued and the
          defaults ``([1.0], [0])`` are returned on the observer's device.
          """
          is_uninitialized = self.min_val == float("inf") and self.max_val == float(
              "-inf"
          )
          if is_uninitialized:
              # Implicit concatenation: the previous backslash continuation
              # embedded the source indentation into the warning text.
              warnings.warn(
                  "must run observer before calling calculate_qparams. "
                  "Returning default scale and zero point"
              )
              # Use the full device (including its index), not just the type.
              return torch.tensor([1.0], device=self.min_val.device), torch.tensor([0], device=self.min_val.device)
          assert self.bins == len(self.histogram), (
              "The number of bins in histogram should be equal to the number of bins "
              "supplied while making this observer"
          )

          new_min, new_max = self._non_linear_param_search()

          return self._calculate_qparams(new_min, new_max)

      def _save_to_state_dict(self, destination, prefix, keep_vars):
          super(HistogramObserver, self)._save_to_state_dict(
              destination, prefix, keep_vars
          )
          destination[prefix + "min_val"] = self.min_val
          destination[prefix + "max_val"] = self.max_val

      def _load_from_state_dict(
          self,
          state_dict,
          prefix,
          local_metadata,
          strict,
          missing_keys,
          unexpected_keys,
          error_msgs,
      ):
          version = local_metadata.get("version", None)

          if version is None or version < 3:
              # if min_val and max_val are not initialized, update their shape
              # to account for the differences between v2 and v3
              min_val_name, max_val_name = prefix + "min_val", prefix + "max_val"
              if min_val_name in state_dict:
                  if state_dict[min_val_name].shape == torch.Size([0]):
                      state_dict[min_val_name] = torch.tensor(float("inf"))
              if max_val_name in state_dict:
                  if state_dict[max_val_name].shape == torch.Size([0]):
                      state_dict[max_val_name] = torch.tensor(float("-inf"))

          local_state = ["min_val", "max_val"]
          for name in local_state:
              key = prefix + name
              if key in state_dict:
                  val = state_dict[key]
                  setattr(self, name, val)
              elif strict:
                  missing_keys.append(key)
          super(HistogramObserver, self)._load_from_state_dict(
              state_dict,
              prefix,
              local_metadata,
              strict,
              missing_keys,
              unexpected_keys,
              error_msgs,
          )
  class FixedQParamsObserver(ObserverBase):
      r"""
      Observer whose quantization parameters are fixed at construction time.

      It simulates quantize and dequantize with fixed quantization parameters
      in training time: ``forward`` returns its input untouched and
      ``calculate_qparams`` always returns the stored buffers. Only per tensor
      quantization is supported.

      Args:
          `scale` (float): fixed scale, stored as a 1-element float buffer.
          `zero_point` (int): fixed zero point, stored as a 1-element int buffer.
          `dtype`, `qscheme`, `quant_min`, `quant_max`: stored as attributes.
      """

      scale: torch.Tensor
      zero_point: torch.Tensor

      def __init__(
          self,
          scale,
          zero_point,
          dtype=torch.quint8,
          qscheme=torch.per_tensor_affine,
          quant_min=0,
          quant_max=255,
      ):
          super().__init__(dtype=dtype)
          self.quant_min = quant_min
          self.quant_max = quant_max
          # Registration order (scale, then zero_point) fixes the state_dict order.
          self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
          self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
          self.dtype = dtype
          self.qscheme = qscheme

      def forward(self, X):
          # The parameters never change, so there is nothing to record.
          return X

      @torch.jit.export
      def calculate_qparams(self):
          return self.scale, self.zero_point
  class PlaceholderObserver(ObserverBase):
      r"""
      Pass-through observer that only carries configuration.

      It observes nothing; it just passes its configuration to the quantized
      module's ``.from_float()``. Can be used for quantization to float16,
      which doesn't require determining ranges.

      Args:
          dtype: Quantized data type
          custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
                          (Can be used in Graph Mode Passes for special case ops).
          compute_dtype: Computation type for dynamic quantization. Stored as
                         ``self.compute_dtype`` only when truthy; otherwise the
                         attribute is not defined at all.
      """

      def __init__(
          self, dtype=torch.float32, custom_op_name="", compute_dtype=None
      ) -> None:
          super().__init__(dtype=dtype)
          # dtype of input of the target operator, e.g. for dynamic quantization
          # ops, the dtype will be float32
          self.dtype = dtype
          self.custom_op = custom_op_name
          # Only define the attribute when a compute dtype was actually given.
          if compute_dtype:
              self.compute_dtype = compute_dtype

      def forward(self, x):
          return x

      @torch.jit.export
      def calculate_qparams(self):
          raise Exception(
              "calculate_qparams should not be called for PlaceholderObserver"
          )
  class RecordingObserver(ObserverBase):
      r"""
      Debug observer that keeps a copy of every tensor it sees.

      Each call to ``forward`` appends ``x.clone()`` to ``tensor_val`` and
      returns ``x`` unchanged; ``get_tensor_value`` returns that list.

      Args:
          dtype: Quantized data type
          **kwargs: forwarded unchanged to ``ObserverBase.__init__``.
      """
      __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]}

      def __init__(self, dtype=torch.quint8, **kwargs):
          super().__init__(dtype=dtype, **kwargs)  # type: ignore[call-arg]
          self.tensor_val = []

      def forward(self, x):
          # Clone so later in-place changes to x don't alter the record.
          snapshot = x.clone()
          self.tensor_val.append(snapshot)
          return x

      @torch.jit.export
      def calculate_qparams(self):
          raise Exception("calculate_qparams should not be called for RecordingObserver")

      @torch.jit.export
      def get_tensor_value(self):
          return self.tensor_val
  class NoopObserver(ObserverBase):
      r"""
      Pass-through observer that only carries configuration.

      It observes nothing and just passes its configuration to the quantized
      module's ``.from_float()``. Primarily used for quantization to float16,
      which doesn't require determining ranges.

      Args:
          dtype: Quantized data type
          custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
                          (Can be used in Graph Mode Passes for special case ops).
      """

      def __init__(self, dtype=torch.float16, custom_op_name="") -> None:
          super().__init__(dtype=dtype)
          self.dtype = dtype
          self.custom_op = custom_op_name

      def forward(self, x):
          return x

      @torch.jit.export
      def calculate_qparams(self):
          raise Exception("calculate_qparams should not be called for NoopObserver")
  class ReuseInputObserver(ObserverBase):
      r"""Marker observer meaning "reuse the observer of my input".

      Typically used for operators like reshape::

          x0 = ...
          x1 = x0.reshape()

      If x0 is configured to be observed by some observer (say
      MinMaxObserver) and reshape is configured with ReuseInputObserver, the
      observer instance of x0 is reused for x1 (the output of reshape). If x0
      is not observed, x1 won't be observed either.

      Note: this is only enabled in FX Graph Mode Quantization
      """

      def __init__(self):
          super().__init__(dtype=torch.quint8)

      def forward(self, x):
          return x

      @torch.jit.export
      def calculate_qparams(self):
          raise Exception("calculate_qparams should not be called for ReuseInputObserver")
  1158. def _is_observer_script_module(mod, obs_type_name):
  1159. """Returns true if given mod is an instance of Observer script module."""
  1160. if isinstance(mod, torch.jit.RecursiveScriptModule):
  1161. # qualified name looks like '__torch__.torch.ao.quantization.observer.___torch_mangle_2.MinMaxObserver'
  1162. suffix = mod._c.qualified_name.split(".", 1)[1]
  1163. name = re.sub(r"\.___torch_mangle_\d+", "", suffix)
  1164. return obs_type_name in name
  1165. return False
  1166. def _is_activation_post_process(module):
  1167. return (
  1168. isinstance(module, torch.ao.quantization.ObserverBase)
  1169. or isinstance(module, torch.ao.quantization.FakeQuantize)
  1170. or _is_observer_script_module(module, "quantization.observer")
  1171. )
  1172. def _is_per_channel_script_obs_instance(module):
  1173. if isinstance(module, torch.jit.RecursiveScriptModule):
  1174. return _is_observer_script_module(
  1175. module, "quantization.observer.PerChannelMinMaxObserver"
  1176. ) or _is_observer_script_module(
  1177. module, "quantization.observer.MovingAveragePerChannelMinMaxObserver"
  1178. )
  1179. return False
  def get_observer_state_dict(mod):
      r"""
      Returns the state dict corresponding to the observer stats.
      Traverse the model state_dict and extract out the stats.

      Args:
          mod: a scripted module, ``GraphModule`` or eager ``nn.Module``.

      Returns:
          ``OrderedDict`` holding the entries of ``mod.state_dict()`` whose
          key contains ``"observer"`` (scripted modules) or
          ``"activation_post_process"`` (GraphModule / eager modules), with
          the ``_metadata`` of the full state dict attached.
      """
      # Build the full state dict once; it walks the entire module tree, and
      # the previous version did so twice (once more just for _metadata).
      full_state = mod.state_dict()
      if isinstance(mod, torch.jit.RecursiveScriptModule):
          marker = "observer"
      else:
          # path for GraphModule and nn.Module (eager mode)
          marker = "activation_post_process"
      od = OrderedDict((k, v) for k, v in full_state.items() if marker in k)
      od._metadata = full_state._metadata  # type: ignore[attr-defined]
      return od
  def load_observer_state_dict(mod, obs_dict):
      r"""
      Given input model and a state_dict containing model observer stats,
      load the stats back into the model. The observer state_dict can be saved
      using torch.ao.quantization.get_observer_state_dict

      Args:
          mod: model whose observer / fake-quantize submodules receive the stats.
          obs_dict: state dict keyed by fully-qualified names, e.g. as returned
              by ``get_observer_state_dict``.

      Raises:
          Exception: if loading reports a missing or unexpected key whose name
              contains ``"observer"`` or ``"activation_post_process"``.
      """
      missing_keys: List[str] = []
      unexpected_keys: List[str] = []
      for name, module in mod.named_modules():
          if not _is_activation_post_process(module):
              continue
          # named_modules() reports the root module with the empty name. Its
          # keys carry no prefix, so using "." there could never match a key.
          prefix = name + "." if name else ""
          if _is_per_channel_script_obs_instance(module):
              # For per-channel observers we need to call a custom load_from_state_dict to resize the tensor.
              # However this is not called when the module is scripted and we end up calling the default one in module.py
              module._load_from_state_dict_script(
                  obs_dict, prefix, {}, True, missing_keys, unexpected_keys, []
              )
          else:
              module._load_from_state_dict(
                  obs_dict, prefix, {}, False, missing_keys, unexpected_keys, []
              )
      for k in missing_keys:
          if "observer" in k or "activation_post_process" in k:
              raise Exception("Missing keys for observer {} in state_dict".format(k))
      for k in unexpected_keys:
          if "observer" in k or "activation_post_process" in k:
              raise Exception("Unexpected keys for observer {} in state_dict".format(k))
  # Restrict activations to be in the range [0, 127]
  default_observer = MinMaxObserver.with_args(quant_min=0, quant_max=127)
  """
  Default observer for static quantization, usually used for debugging.
  """

  default_placeholder_observer = PlaceholderObserver
  """
  Default placeholder observer, usually used for quantization to torch.float16.
  """

  default_debug_observer = RecordingObserver
  """
  Default debug-only observer.
  """

  default_weight_observer = MinMaxObserver.with_args(
      dtype=torch.qint8, qscheme=torch.per_tensor_symmetric
  )
  """
  Default weight observer.
  """

  weight_observer_range_neg_127_to_127 = MinMaxObserver.with_args(
      dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
      quant_min=-127, quant_max=127, eps=2 ** -12)
  """
  Symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
  """

  default_histogram_observer = HistogramObserver.with_args(quant_min=0, quant_max=127)
  """
  Default histogram observer, usually used for PTQ.
  """

  default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
      dtype=torch.qint8, qscheme=torch.per_channel_symmetric
  )
  """
  Default per-channel weight observer, usually used on backends where per-channel
  weight quantization is supported, such as `fbgemm`.
  """

  # A per-channel qscheme needs PerChannelMinMaxObserver: MinMaxObserver only
  # supports per-tensor qschemes and rejects torch.per_channel_symmetric.
  per_channel_weight_observer_range_neg_127_to_127 = PerChannelMinMaxObserver.with_args(
      dtype=torch.qint8, qscheme=torch.per_channel_symmetric,
      quant_min=-127, quant_max=127, eps=2 ** -12)
  """
  Per-channel, symmetric weight observer with the 8-bit values restricted to [-127, +127], excluding -128.
  """

  default_dynamic_quant_observer = PlaceholderObserver.with_args(
      dtype=torch.float, compute_dtype=torch.quint8
  )
  """
  Default observer for dynamic quantization.
  """

  default_float_qparams_observer = PerChannelMinMaxObserver.with_args(
      dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
  )
  """
  Default observer for a floating point zero-point.
  """

  default_float_qparams_observer_4bit = PerChannelMinMaxObserver.with_args(
      dtype=torch.quint4x2, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0
  )
  """
  Default observer for a floating point zero-point and 4 bit activations.
  """

  # TODO(future PR): remove these defaults and enforce activation functions
  # to explicitly specify their output range
  default_fixed_qparams_range_neg1to1_observer = FixedQParamsObserver.with_args(
      scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
  default_fixed_qparams_range_0to1_observer = FixedQParamsObserver.with_args(
      scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)
  # TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
  default_symmetric_fixed_qparams_observer = default_fixed_qparams_range_neg1to1_observer
  default_affine_fixed_qparams_observer = default_fixed_qparams_range_0to1_observer
  """
  Default observers for fixed qparams operations.
  """

  default_reuse_input_observer = ReuseInputObserver
  """
  Default observer for operators like reshape that reuses the observer of input to
  the operator
  """