utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. """
  2. Utils shared by different modes of quantization (eager/graph)
  3. """
  4. import warnings
  5. import functools
  6. import torch
  7. from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
  8. from typing import Tuple, Any, Union, Callable
  9. from torch.nn.utils.parametrize import is_parametrized
  10. # Type for fusion patterns, it can be more complicated than the following actually,
  11. # see pattern.md for docs
  12. # TODO: not sure if typing supports recursive data types
  13. Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
  14. # TODO: maybe rename this to MatchInputNode
  15. class MatchAllNode:
  16. """ A node pattern that matches all nodes, used in defining
  17. fusion patterns in FX Graph Mode Quantization
  18. """
  19. pass
  20. module_type_list = {
  21. torch.nn.ReLU,
  22. torch.nn.ReLU6,
  23. torch.nn.AdaptiveAvgPool1d,
  24. torch.nn.AdaptiveAvgPool2d,
  25. torch.nn.AdaptiveAvgPool3d,
  26. torch.nn.AvgPool1d,
  27. torch.nn.AvgPool2d,
  28. torch.nn.AvgPool3d,
  29. torch.nn.MaxPool1d,
  30. torch.nn.MaxPool2d,
  31. torch.nn.MaxPool3d,
  32. torch.nn.Identity,
  33. torch.nn.Hardsigmoid,
  34. torch.nn.Sigmoid,
  35. torch.nn.Tanh,
  36. }
  37. func_list = {
  38. torch.nn.functional.adaptive_avg_pool1d,
  39. torch.nn.functional.adaptive_avg_pool2d,
  40. torch.nn.functional.adaptive_avg_pool3d,
  41. torch.nn.functional.elu,
  42. torch.nn.functional.hardswish,
  43. torch.nn.functional.instance_norm,
  44. torch.nn.functional.layer_norm,
  45. torch.nn.functional.leaky_relu,
  46. torch.nn.functional.silu,
  47. torch.nn.functional.mish,
  48. torch.nn.functional.dropout,
  49. torch.nn.functional.max_pool1d,
  50. torch.nn.functional.max_pool2d,
  51. torch.nn.functional.max_pool3d,
  52. torch.nn.functional.relu,
  53. torch.nn.functional.hardtanh,
  54. torch.nn.functional.hardtanh_,
  55. torch.nn.functional.hardsigmoid,
  56. torch.nn.functional.sigmoid,
  57. torch.transpose,
  58. torch.repeat_interleave,
  59. torch.sigmoid,
  60. torch.squeeze,
  61. torch.stack,
  62. torch.sum,
  63. torch.tanh,
  64. torch.unsqueeze,
  65. torch.cat,
  66. }
  67. method_list = {
  68. torch.mean,
  69. 'relu',
  70. 'relu_',
  71. 'contiguous',
  72. 'detach',
  73. 'detach_',
  74. 'hardsigmoid',
  75. 'hardsigmoid_',
  76. 'permute',
  77. 'repeat',
  78. 'repeat_interleave',
  79. 'reshape',
  80. 'resize_',
  81. 'shape',
  82. 'sigmoid',
  83. 'sigmoid_',
  84. 'size',
  85. 'squeeze',
  86. 'squeeze_',
  87. 'tanh',
  88. 'tanh_',
  89. 'transpose',
  90. 'unsqueeze',
  91. 'unsqueeze_',
  92. 'view',
  93. }
  94. def check_node(node, modules):
  95. # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
  96. is_call_function = node.op == "call_function" and node.target in func_list
  97. is_call_method = node.op == "call_method" and node.target in method_list
  98. is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
  99. return is_call_function, is_call_method, is_call_module
  100. def get_combined_dict(default_dict, additional_dict):
  101. d = default_dict.copy()
  102. d.update(additional_dict)
  103. return d
  104. def is_per_tensor(qscheme):
  105. return qscheme == torch.per_tensor_affine or \
  106. qscheme == torch.per_tensor_symmetric
  107. def is_per_channel(qscheme):
  108. return qscheme in [torch.per_channel_affine,
  109. torch.per_channel_affine_float_qparams,
  110. torch.per_channel_symmetric]
  111. def getattr_from_fqn(obj: Any, fqn: str) -> Any:
  112. """
  113. Given an obj and a fqn such as "foo.bar.baz", returns gm.foo.bar.baz.
  114. """
  115. return functools.reduce(getattr, fqn.split("."), obj)
  116. def get_qparam_dict(observer_or_fake_quant):
  117. qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
  118. dtype = observer_or_fake_quant.dtype
  119. qparams = {"qscheme": qscheme, "dtype": dtype}
  120. if not qscheme:
  121. return qparams
  122. if is_per_tensor(qscheme):
  123. qscheme = torch.per_tensor_affine
  124. elif is_per_channel(qscheme):
  125. # change symmetric to affine since we do not have symmetric
  126. # quantized Tensor
  127. if qscheme == torch.per_channel_symmetric:
  128. qscheme = torch.per_channel_affine
  129. qparams["axis"] = observer_or_fake_quant.ch_axis
  130. else:
  131. raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
  132. # update qscheme, since we don't have symmetric quant qscheme
  133. # in quantized Tensor
  134. qparams["qscheme"] = qscheme
  135. scale, zero_point = observer_or_fake_quant.calculate_qparams()
  136. qparams["scale"] = scale
  137. qparams["zero_point"] = zero_point
  138. return qparams
  139. def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
  140. """ Get the observed/quantized custom module class that we need
  141. to swap `custom_module` to
  142. Input:
  143. custom_module: input, can be an instance of either a float or observed custom module
  144. custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
  145. qconfig: qconfig configured for the custom module
  146. Output:
  147. corresponding observed/quantized custom module class for input custom module instance
  148. """
  149. quant_type = get_quant_type(qconfig)
  150. quant_type_str = quant_type_to_str(quant_type)
  151. class_mapping = custom_module_class_mapping.get(quant_type_str, {})
  152. assert type(custom_module) in class_mapping, "did not find corresponding observed " \
  153. "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
  154. return class_mapping[type(custom_module)]
  155. def activation_dtype(qconfig):
  156. assert qconfig is not None
  157. activation = qconfig.activation()
  158. return activation.dtype
  159. def weight_dtype(qconfig):
  160. assert qconfig is not None
  161. weight = qconfig.weight()
  162. return weight.dtype
  163. def activation_is_statically_quantized(qconfig):
  164. """ Given a qconfig, decide if the activation needs to be
  165. quantized or not, this includes quantizing to quint8, qint8 and float16
  166. """
  167. return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
  168. def activation_is_dynamically_quantized(qconfig):
  169. """ Given a qconfig, decide if the activation needs to be
  170. dynamically quantized or not, this includes dynamically quantizing to
  171. quint8, qint8 and float16
  172. """
  173. activation_dtype, _, activation_compute_dtype = \
  174. get_qconfig_dtypes(qconfig)
  175. return activation_dtype == torch.float and \
  176. activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]
  177. def activation_is_int8_quantized(qconfig):
  178. """ Given a qconfig, decide if the activation needs to be
  179. quantized to int8 or not, this includes quantizing to quint8, qint8
  180. """
  181. return activation_dtype(qconfig) in [torch.quint8, torch.qint8]
  182. def activation_is_int32_quantized(qconfig):
  183. """ Given a qconfig, decide if the activation needs to be
  184. quantized to int32 or not
  185. """
  186. return activation_dtype(qconfig) == torch.qint32
  187. def weight_is_quantized(qconfig):
  188. """ Given a qconfig, decide if the weight needs to be
  189. quantized or not
  190. """
  191. return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2]
  192. def weight_is_statically_quantized(qconfig):
  193. """ Given a qconfig, decide if the weight needs to be statically
  194. quantized or not
  195. """
  196. return weight_dtype(qconfig) in [torch.quint8, torch.qint8]
  197. def op_is_int8_dynamically_quantized(qconfig) -> bool:
  198. """ Given a qconfig, returns True if this op is using int8 dynamic
  199. quantization
  200. """
  201. activation_dtype, weight_dtype, activation_compute_dtype = \
  202. get_qconfig_dtypes(qconfig)
  203. return (
  204. activation_dtype is torch.float and
  205. # for now, the lines below assume fbgemm or qnnpack
  206. weight_dtype is torch.qint8 and
  207. activation_compute_dtype is torch.quint8
  208. )
  209. def get_qconfig_dtypes(qconfig):
  210. r""" returns the qconfig tuple for qconfig:
  211. (activation_dtype, weight_dtype, activation_compute_dtype)
  212. """
  213. assert qconfig is not None
  214. activation = qconfig.activation()
  215. weight = qconfig.weight()
  216. compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
  217. return (activation.dtype, weight.dtype, compute_dtype)
  218. def get_quant_type(qconfig):
  219. assert qconfig is not None
  220. activation = qconfig.activation()
  221. weight = qconfig.weight()
  222. static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2]
  223. if weight.dtype in static_dtypes:
  224. if activation.dtype in static_dtypes:
  225. return QuantType.STATIC
  226. elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes:
  227. return QuantType.DYNAMIC
  228. else:
  229. return QuantType.WEIGHT_ONLY
  230. if weight.dtype == torch.float16:
  231. if activation.dtype == torch.float:
  232. return QuantType.DYNAMIC
  233. elif activation.dtype == torch.float16:
  234. return QuantType.STATIC
  235. raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
  236. "weight({})".format(activation.dtype, weight.dtype))
  237. def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
  238. """ Checks if the given minimum and maximum values are valid, meaning that
  239. they exist and the min value is less than the max value.
  240. """
  241. if min_val.numel() == 0 or max_val.numel() == 0:
  242. warnings.warn(
  243. "must run observer before calling calculate_qparams. " +
  244. "Returning default values."
  245. )
  246. return False
  247. if min_val.dim() == 0 or max_val.dim() == 0:
  248. if min_val == float("inf") and max_val == float("-inf"):
  249. warnings.warn(
  250. "must run observer before calling calculate_qparams. " +
  251. "Returning default values."
  252. )
  253. return False
  254. assert min_val <= max_val, "min {} should be less than max {}".format(
  255. min_val, max_val
  256. )
  257. else:
  258. assert torch.all(
  259. min_val <= max_val
  260. ), "min {} should be less than max {}".format(min_val, max_val)
  261. return True
  262. def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
  263. reduce_range: bool) -> Tuple[int, int]:
  264. r"""Calculates actual qmin and qmax based on the quantization range,
  265. observer datatype and if range is reduced.
  266. """
  267. # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
  268. if has_customized_qrange:
  269. # This initialization here is to be resolve TorchScript compilation issues and allow
  270. # using of refinement to decouple initial_qmin and initial_qmax from quantization range.
  271. # The actual values of initial_qmin and initial_qmax will be reset below.
  272. if dtype == torch.qint32:
  273. initial_quant_min, initial_quant_max = 0, 2**31 - 1
  274. else:
  275. initial_quant_min, initial_quant_max = 0, 255
  276. # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
  277. # attribute from Optional valid integers for use, based on TorchScript's requirements.
  278. custom_quant_min, custom_quant_max = quant_min, quant_max
  279. if custom_quant_min is not None and custom_quant_max is not None:
  280. initial_quant_min, initial_quant_max = (
  281. custom_quant_min,
  282. custom_quant_max,
  283. )
  284. qrange_len = initial_quant_max - initial_quant_min + 1
  285. if dtype == torch.qint8:
  286. assert (
  287. 0 < qrange_len <= 256
  288. ), "quantization range should be positive and not exceed the maximum bit range (=256)."
  289. elif dtype == torch.qint32:
  290. assert (
  291. 0 < qrange_len <= 2**31
  292. ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
  293. if reduce_range:
  294. quant_min, quant_max = quant_min // 2, quant_max // 2
  295. else:
  296. # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
  297. if dtype == torch.qint8:
  298. if reduce_range:
  299. quant_min, quant_max = -64, 63
  300. else:
  301. quant_min, quant_max = -128, 127
  302. elif dtype == torch.quint8:
  303. if reduce_range:
  304. quant_min, quant_max = 0, 127
  305. else:
  306. quant_min, quant_max = 0, 255
  307. elif dtype == torch.qint32:
  308. quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
  309. else:
  310. quant_min, quant_max = 0, 15
  311. return quant_min, quant_max
  312. def _parent_name(target):
  313. """
  314. Turn 'foo.bar' into ['foo', 'bar']
  315. """
  316. r = target.rsplit('.', 1)
  317. if len(r) == 1:
  318. return '', r[0]
  319. else:
  320. return r[0], r[1]
  321. def has_no_children_ignoring_parametrizations(module):
  322. """
  323. Checks if module._modules is empty or
  324. if module is a parametrization, checks that module._modules only has
  325. the 'parametrizations' module
  326. """
  327. if len(module._modules) == 0:
  328. return True
  329. elif is_parametrized(module):
  330. return len(module._modules) == 1 and 'parametrizations' in module._modules
  331. else:
  332. return False