| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372 |
- """
- Utils shared by different modes of quantization (eager/graph)
- """
- import warnings
- import functools
- import torch
- from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
- from typing import Tuple, Any, Union, Callable
- from torch.nn.utils.parametrize import is_parametrized
# Type alias for fusion patterns. Real patterns may nest more deeply than the
# shapes spelled out below (hence the trailing ``Any``); see pattern.md for docs.
# TODO: not sure if typing supports recursive data types
Pattern = Union[
    Callable,
    Tuple[Callable, Callable],
    Tuple[Callable, Tuple[Callable, Callable]],
    Any,
]
# TODO: maybe rename this to MatchInputNode
class MatchAllNode:
    """Wildcard node pattern that matches every node.

    Used as a placeholder when defining fusion patterns in
    FX Graph Mode Quantization. The class carries no state or behavior.
    """
# Module classes that `check_node` recognizes for "call_module" nodes.
# NOTE: despite the `_list` suffix this is a set, so member order is irrelevant.
module_type_list = {
    # activations
    torch.nn.ReLU,
    torch.nn.ReLU6,
    torch.nn.Hardsigmoid,
    torch.nn.Sigmoid,
    torch.nn.Tanh,
    # pooling
    torch.nn.AdaptiveAvgPool1d,
    torch.nn.AdaptiveAvgPool2d,
    torch.nn.AdaptiveAvgPool3d,
    torch.nn.AvgPool1d,
    torch.nn.AvgPool2d,
    torch.nn.AvgPool3d,
    torch.nn.MaxPool1d,
    torch.nn.MaxPool2d,
    torch.nn.MaxPool3d,
    # pass-through
    torch.nn.Identity,
}
# Callables that `check_node` recognizes for "call_function" nodes.
# NOTE: despite the `_list` suffix this is a set, so member order is irrelevant.
func_list = {
    # torch.nn.functional: pooling
    torch.nn.functional.adaptive_avg_pool1d,
    torch.nn.functional.adaptive_avg_pool2d,
    torch.nn.functional.adaptive_avg_pool3d,
    torch.nn.functional.max_pool1d,
    torch.nn.functional.max_pool2d,
    torch.nn.functional.max_pool3d,
    # torch.nn.functional: activations
    torch.nn.functional.elu,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.nn.functional.hardtanh_,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.mish,
    torch.nn.functional.relu,
    torch.nn.functional.sigmoid,
    torch.nn.functional.silu,
    # torch.nn.functional: normalization and regularization
    torch.nn.functional.dropout,
    torch.nn.functional.instance_norm,
    torch.nn.functional.layer_norm,
    # top-level torch functions
    torch.cat,
    torch.repeat_interleave,
    torch.sigmoid,
    torch.squeeze,
    torch.stack,
    torch.sum,
    torch.tanh,
    torch.transpose,
    torch.unsqueeze,
}
# Targets that `check_node` recognizes for "call_method" nodes.
# NOTE: despite the `_list` suffix this is a set, so member order is irrelevant.
method_list = {
    # NOTE(review): torch.mean is a function object, while every other entry
    # is a method-name string -- confirm whether 'mean' (or an entry in
    # func_list) was intended.
    torch.mean,
    # activations (out-of-place and in-place variants)
    'hardsigmoid',
    'hardsigmoid_',
    'relu',
    'relu_',
    'sigmoid',
    'sigmoid_',
    'tanh',
    'tanh_',
    # shape / layout manipulation
    'contiguous',
    'permute',
    'repeat',
    'repeat_interleave',
    'reshape',
    'resize_',
    'squeeze',
    'squeeze_',
    'transpose',
    'unsqueeze',
    'unsqueeze_',
    'view',
    # metadata queries
    'shape',
    'size',
    # autograd
    'detach',
    'detach_',
}
def check_node(node, modules):
    """Classify ``node`` against the module-level op tables.

    Args:
        node: object exposing ``op`` and ``target`` attributes.
        modules: mapping indexed by ``str(node.target)`` that yields the
            module instance behind a "call_module" node.

    Returns:
        A 3-tuple ``(is_call_function, is_call_method, is_call_module)``;
        each flag is True when the node has that op kind and its target
        (or, for modules, the type of the looked-up module) is in
        ``func_list`` / ``method_list`` / ``module_type_list`` respectively.
    """
    # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
    op_kind = node.op
    matches_function = op_kind == "call_function" and node.target in func_list
    matches_method = op_kind == "call_method" and node.target in method_list
    # The module table is only indexed for call_module nodes (short-circuit).
    matches_module = (
        op_kind == "call_module"
        and type(modules[str(node.target)]) in module_type_list
    )
    return matches_function, matches_method, matches_module
def get_combined_dict(default_dict, additional_dict):
    """Return a copy of ``default_dict`` updated with ``additional_dict``.

    Entries in ``additional_dict`` win on key collisions. ``default_dict``
    itself is not modified; the result comes from its ``copy()`` method.
    """
    merged = default_dict.copy()
    merged.update(additional_dict)
    return merged
def is_per_tensor(qscheme):
    """Return True if ``qscheme`` is one of the per-tensor schemes
    (``torch.per_tensor_affine`` or ``torch.per_tensor_symmetric``).
    """
    per_tensor_schemes = (torch.per_tensor_affine, torch.per_tensor_symmetric)
    return qscheme in per_tensor_schemes
def is_per_channel(qscheme):
    """Return True if ``qscheme`` is one of the per-channel schemes
    (affine, affine with float qparams, or symmetric).
    """
    per_channel_schemes = (
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
        torch.per_channel_symmetric,
    )
    return qscheme in per_channel_schemes
def getattr_from_fqn(obj: Any, fqn: str) -> Any:
    """Resolve a dotted attribute path starting from ``obj``.

    Given ``obj`` and a fqn such as "foo.bar.baz", returns
    ``obj.foo.bar.baz``. Each path segment is looked up with ``getattr``,
    so a missing segment raises ``AttributeError``.
    """
    current = obj
    for attr_name in fqn.split("."):
        current = getattr(current, attr_name)
    return current
def get_qparam_dict(observer_or_fake_quant):
    """Collect quantization parameters from an observer / fake-quant object.

    Args:
        observer_or_fake_quant: object exposing ``dtype`` and optionally
            ``qscheme``; when a qscheme is present it must also provide
            ``calculate_qparams()`` and, for per-channel schemes, ``ch_axis``.

    Returns:
        A dict with keys "qscheme" and "dtype". When the qscheme is truthy
        the dict additionally holds "scale" and "zero_point" (plus "axis"
        for per-channel schemes), and "qscheme" is normalized to its affine
        counterpart where the code maps one.

    Raises:
        RuntimeError: if the qscheme is neither per-tensor nor per-channel.
    """
    source = observer_or_fake_quant
    qscheme = getattr(source, "qscheme", None)
    qparams = {"qscheme": qscheme, "dtype": source.dtype}

    # Nothing further to compute without a qscheme.
    if not qscheme:
        return qparams

    if is_per_tensor(qscheme):
        normalized_qscheme = torch.per_tensor_affine
    elif is_per_channel(qscheme):
        # change symmetric to affine since we do not have symmetric
        # quantized Tensor
        if qscheme == torch.per_channel_symmetric:
            normalized_qscheme = torch.per_channel_affine
        else:
            normalized_qscheme = qscheme
        qparams["axis"] = source.ch_axis
    else:
        raise RuntimeError(f"Unrecognized qscheme: {qscheme}")

    # update qscheme, since we don't have symmetric quant qscheme
    # in quantized Tensor
    qparams["qscheme"] = normalized_qscheme

    scale, zero_point = source.calculate_qparams()
    qparams["scale"] = scale
    qparams["zero_point"] = zero_point
    return qparams
def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
    """Look up the observed/quantized class that `custom_module` should be swapped to.

    Input:
        custom_module: input, can be an instance of either a float or observed custom module
        custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping,
            keyed first by the quant type string and then by module class
        qconfig: qconfig configured for the custom module

    Output:
        corresponding observed/quantized custom module class for input custom module instance

    Raises an AssertionError when the type of `custom_module` has no entry
    in the mapping selected by the qconfig's quant type.
    """
    quant_type_key = quant_type_to_str(get_quant_type(qconfig))
    class_mapping = custom_module_class_mapping.get(quant_type_key, {})
    module_cls = type(custom_module)
    assert module_cls in class_mapping, "did not find corresponding observed " \
        "module class for {} in mapping: {}".format(module_cls, class_mapping)
    return class_mapping[module_cls]
def activation_dtype(qconfig):
    """Return the ``dtype`` of the object built by ``qconfig.activation()``.

    ``qconfig`` must not be None.
    """
    assert qconfig is not None
    return qconfig.activation().dtype
def weight_dtype(qconfig):
    """Return the ``dtype`` of the object built by ``qconfig.weight()``.

    ``qconfig`` must not be None.
    """
    assert qconfig is not None
    return qconfig.weight().dtype
def activation_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized or not, this includes quantizing to quint8, qint8 and float16
    """
    static_activation_dtypes = [torch.quint8, torch.qint8, torch.float16]
    return activation_dtype(qconfig) in static_activation_dtypes
def activation_is_dynamically_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    dynamically quantized or not, this includes dynamically quantizing to
    quint8, qint8 and float16

    True only when the activation dtype is torch.float and its compute
    dtype is one of quint8, qint8 or float16.
    """
    # Local names deliberately differ from the module-level
    # `activation_dtype` function to avoid shadowing it.
    act_dtype, _, act_compute_dtype = get_qconfig_dtypes(qconfig)
    if act_dtype != torch.float:
        return False
    return act_compute_dtype in [torch.quint8, torch.qint8, torch.float16]
def activation_is_int8_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int8 or not, this includes quantizing to quint8, qint8
    """
    int8_dtypes = [torch.quint8, torch.qint8]
    return activation_dtype(qconfig) in int8_dtypes
def activation_is_int32_quantized(qconfig):
    """ Given a qconfig, decide if the activation needs to be
    quantized to int32 or not (i.e. its dtype is torch.qint32)
    """
    act_dtype = activation_dtype(qconfig)
    return act_dtype == torch.qint32
def weight_is_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be
    quantized or not; this covers quint8, qint8, float16 and quint4x2
    """
    quantized_weight_dtypes = [
        torch.quint8,
        torch.qint8,
        torch.float16,
        torch.quint4x2,
    ]
    return weight_dtype(qconfig) in quantized_weight_dtypes
def weight_is_statically_quantized(qconfig):
    """ Given a qconfig, decide if the weight needs to be statically
    quantized or not (i.e. its dtype is quint8 or qint8)
    """
    static_weight_dtypes = [torch.quint8, torch.qint8]
    return weight_dtype(qconfig) in static_weight_dtypes
def op_is_int8_dynamically_quantized(qconfig) -> bool:
    """ Given a qconfig, returns True if this op is using int8 dynamic
    quantization

    That is: float activation dtype, qint8 weight dtype and quint8
    activation compute dtype.
    """
    act_dtype, w_dtype, act_compute_dtype = get_qconfig_dtypes(qconfig)
    if act_dtype is not torch.float:
        return False
    # for now, the checks below assume fbgemm or qnnpack
    if w_dtype is not torch.qint8:
        return False
    return act_compute_dtype is torch.quint8
def get_qconfig_dtypes(qconfig):
    r""" Returns the dtype tuple for qconfig:
    (activation_dtype, weight_dtype, activation_compute_dtype)

    The objects are obtained by calling ``qconfig.activation()`` and
    ``qconfig.weight()``. ``activation_compute_dtype`` is None when the
    activation object has no ``compute_dtype`` attribute.
    """
    assert qconfig is not None
    act = qconfig.activation()
    wt = qconfig.weight()
    act_compute_dtype = getattr(act, 'compute_dtype', None)
    return (act.dtype, wt.dtype, act_compute_dtype)
def get_quant_type(qconfig):
    """Derive the QuantType implied by a qconfig's activation/weight dtypes.

    Rules, evaluated in order:
      * weight in {quint8, qint8, quint4x2}:
          - activation dtype in the same set          -> STATIC
          - activation ``compute_dtype`` in that set  -> DYNAMIC
          - otherwise                                 -> WEIGHT_ONLY
      * weight is float16:
          - activation dtype float   -> DYNAMIC
          - activation dtype float16 -> STATIC

    Raises:
        Exception: for any other dtype combination.
    """
    assert qconfig is not None
    act = qconfig.activation()
    wt = qconfig.weight()
    static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2]

    if wt.dtype in static_dtypes:
        if act.dtype in static_dtypes:
            return QuantType.STATIC
        if hasattr(act, 'compute_dtype') and act.compute_dtype in static_dtypes:
            return QuantType.DYNAMIC
        return QuantType.WEIGHT_ONLY

    if wt.dtype == torch.float16:
        if act.dtype == torch.float:
            return QuantType.DYNAMIC
        if act.dtype == torch.float16:
            return QuantType.STATIC

    raise Exception("Unrecognized dtype combination in get_quant_type: activation({}),"
                    "weight({})".format(act.dtype, wt.dtype))
def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
    """ Checks if the given minimum and maximum values are valid, meaning that
    they exist and the min value does not exceed the max value.

    Returns False (after emitting a warning) when either tensor is empty, or
    when a 0-dim pair still holds (inf, -inf). Returns True when
    min <= max holds (element-wise for non-scalar tensors).

    Raises an AssertionError when min > max.
    """
    has_no_elements = min_val.numel() == 0 or max_val.numel() == 0
    if has_no_elements:
        warnings.warn(
            "must run observer before calling calculate_qparams. " +
            "Returning default values."
        )
        return False

    either_is_scalar = min_val.dim() == 0 or max_val.dim() == 0
    if not either_is_scalar:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than max {}".format(min_val, max_val)
        return True

    # Scalar case: (inf, -inf) is treated as "observer has not run yet".
    if min_val == float("inf") and max_val == float("-inf"):
        warnings.warn(
            "must run observer before calling calculate_qparams. " +
            "Returning default values."
        )
        return False

    assert min_val <= max_val, "min {} should be less than max {}".format(
        min_val, max_val
    )
    return True
def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
                        reduce_range: bool) -> Tuple[int, int]:
    r"""Calculates actual qmin and qmax based on the quantization range,
    observer datatype and if range is reduced.

    Args:
        quant_min: custom lower bound; only consulted when
            ``has_customized_qrange`` is True.
        quant_max: custom upper bound; only consulted when
            ``has_customized_qrange`` is True.
        has_customized_qrange: if True, validate and use the custom bounds;
            otherwise fall back to the per-dtype defaults.
        dtype: quantized dtype that selects the validation limit / defaults.
        reduce_range: if True, halve the range (floor-dividing the custom
            bounds by 2, or picking the reduced 8-bit defaults).

    Returns:
        The ``(quant_min, quant_max)`` pair to use.

    Raises:
        AssertionError: if a custom range is empty, or spans more than 256
            values for qint8 or more than 2**32 values for qint32.
    """
    # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
    if has_customized_qrange:
        # Placeholder bounds kept for TorchScript type refinement; they are
        # only used for the length check when either custom bound is None.
        if dtype == torch.qint32:
            initial_quant_min, initial_quant_max = 0, 2**31 - 1
        else:
            initial_quant_min, initial_quant_max = 0, 255
        # Copy to locals so the None check refines the values to plain ints.
        custom_quant_min, custom_quant_max = quant_min, quant_max
        if custom_quant_min is not None and custom_quant_max is not None:
            initial_quant_min, initial_quant_max = (
                custom_quant_min,
                custom_quant_max,
            )

        qrange_len = initial_quant_max - initial_quant_min + 1
        if dtype == torch.qint8:
            assert (
                0 < qrange_len <= 256
            ), "quantization range should be positive and not exceed the maximum bit range (=256)."
        elif dtype == torch.qint32:
            # A 32-bit type spans 2**32 (=4294967296) values, which is what
            # the message states and what the default qint32 range below
            # covers; the previous 2**31 limit rejected that full range.
            assert (
                0 < qrange_len <= 2**32
            ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
        if reduce_range:
            quant_min, quant_max = quant_min // 2, quant_max // 2
    else:
        # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
        if dtype == torch.qint8:
            if reduce_range:
                quant_min, quant_max = -64, 63
            else:
                quant_min, quant_max = -128, 127
        elif dtype == torch.quint8:
            if reduce_range:
                quant_min, quant_max = 0, 127
            else:
                quant_min, quant_max = 0, 255
        elif dtype == torch.qint32:
            quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
        else:
            # Every other dtype gets a 4-bit style range.
            quant_min, quant_max = 0, 15
    return quant_min, quant_max
def _parent_name(target):
    """
    Split a dotted path at its last '.' into ``(parent, name)``.

    Turn 'foo.bar' into ('foo', 'bar'); a target without any '.' yields
    ('', target).
    """
    parent, _, name = target.rpartition('.')
    return parent, name
def has_no_children_ignoring_parametrizations(module):
    """
    Checks if module._modules is empty or,
    if module is parametrized, checks that module._modules only has
    the 'parametrizations' module
    """
    children = module._modules
    if not children:
        return True
    if not is_parametrized(module):
        return False
    # Parametrized: the only allowed child is the 'parametrizations' container.
    return len(children) == 1 and 'parametrizations' in children
|