--- /dev/null
+from torch.testing._internal.common_utils import TestCase
+
+import importlib
+from typing import List
+
+
+class AOMigrationTestCase(TestCase):
+ def _test_package_import(self, package_name: str):
+ r"""Tests the module import by making sure that all the internals match
+ (except the dunder methods)."""
+ old_module = importlib.import_module(f'torch.quantization.{package_name}')
+ new_module = importlib.import_module(f'torch.ao.quantization.{package_name}')
+ old_module_dir = set(dir(old_module))
+ new_module_dir = set(dir(new_module))
+ # Remove dunder attributes before performing the subset check below
+ for el in list(old_module_dir):
+ if el[:2] == '__' and el[-2:] == '__':
+ old_module_dir.remove(el)
+ assert (old_module_dir <= new_module_dir), \
+ f"Importing {old_module} vs. {new_module} does not match: " \
+ f"{old_module_dir - new_module_dir}"
+
+ def _test_function_import(self, package_name: str, function_list: List[str]):
+ r"""Tests the import of individual functions by comparing the function
+ objects and their hashes across the old and new locations."""
+ old_location = importlib.import_module(f'torch.quantization.{package_name}')
+ new_location = importlib.import_module(f'torch.ao.quantization.{package_name}')
+ for fn_name in function_list:
+ old_function = getattr(old_location, fn_name)
+ new_function = getattr(new_location, fn_name)
+ assert old_function == new_function, f"Functions don't match: {fn_name}"
+ assert hash(old_function) == hash(new_function), \
+ f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
+ f"{new_function}({hash(new_function)})"
+
+
+class TestAOMigrationQuantizePy(AOMigrationTestCase):
+ def test_package_import(self):
+ self._test_package_import('quantize')
+
+ def test_function_import(self):
+ function_list = [
+ '_convert',
+ '_observer_forward_hook',
+ '_propagate_qconfig_helper',
+ '_remove_activation_post_process',
+ '_remove_qconfig',
+ 'add_observer_',
+ 'add_quant_dequant',
+ 'convert',
+ 'get_observer_dict',
+ 'get_unique_devices_',
+ 'is_activation_post_process',
+ 'prepare',
+ 'prepare_qat',
+ 'propagate_qconfig_',
+ 'quantize',
+ 'quantize_dynamic',
+ 'quantize_qat',
+ 'register_activation_post_process_hook',
+ 'swap_module',
+ ]
+ self._test_function_import('quantize', function_list)
from quantization.jit.test_fusion_passes import TestFusionPasses # noqa: F401
from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized # noqa: F401
+# AO Migration tests
+from quantization.ao_migration.test_quantize import TestAOMigrationQuantizePy # noqa: F401
+
if __name__ == '__main__':
run_tests()
--- /dev/null
+import copy
+import itertools
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.quantized as nnq
+from torch.nn.intrinsic import _FusedModule
+
+from torch.quantization.quantization_mappings import (
+ get_default_dynamic_quant_module_mappings,
+ get_default_static_quant_module_mappings,
+ get_default_qat_module_mappings,
+ get_default_qconfig_propagation_list,
+ no_observer_set,
+ _has_special_act_post_process,
+ _get_special_act_post_process,
+)
+
+from torch.quantization.stubs import DeQuantStub, QuantWrapper
+from torch.quantization.qconfig import (
+ add_module_to_qconfig_obs_ctr,
+ default_dynamic_qconfig,
+ float16_dynamic_qconfig,
+ float_qparams_weight_only_qconfig)
+
+def is_activation_post_process(module):
+ return (isinstance(module, torch.quantization.ObserverBase) or
+ isinstance(module, torch.quantization.FakeQuantizeBase))
+
+def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
+ qconfig_parent=None, prefix=''):
+ r"""This is a helper function for `propagate_qconfig_`
+
+ Args:
+ module: input module
+ qconfig_dict: dictionary that maps from name of submodule to quantization
+ configuration
+ allow_list: list of quantizable modules
+ qconfig_parent: quantization config of the parent module; we fall back to
+ this config when there is no config specified for the
+ current module
+ prefix: corresponding prefix of the current module, used as key in
+ qconfig_dict
+
+ Return:
+ None, module is modified inplace with qconfig attached
+ """
+ # TODO: Add test
+ if allow_list is None:
+ allow_list = get_default_qconfig_propagation_list()
+
+ module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
+ module_qconfig = qconfig_dict.get(prefix, module_qconfig)
+ module_qconfig = getattr(module, 'qconfig', module_qconfig)
+
+ torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)
+
+ qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
+ module.qconfig = qconfig_with_device_check
+
+ for name, child in module.named_children():
+ module_prefix = prefix + '.' + name if prefix else name
+ _propagate_qconfig_helper(child, qconfig_dict, allow_list,
+ qconfig_with_device_check, module_prefix)
+
+# TODO(jerryzh): expose allow_list
+def propagate_qconfig_(module, qconfig_dict=None, allow_list=None):
+ r"""Propagate qconfig through the module hierarchy and assign `qconfig`
+ attribute on each leaf module
+
+ Args:
+ module: input module
+ qconfig_dict: dictionary that maps from name or type of submodule to
+ quantization configuration, qconfig applies to all submodules of a
+ given module unless qconfig for the submodules are specified (when
+ the submodule already has qconfig attribute)
+
+ Return:
+ None, module is modified inplace with qconfig attached
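+
+ A small illustrative sketch (``model`` and its ``features`` submodule are
+ placeholders for any float module hierarchy):
+
+ .. code-block:: python
+
+ qconfig_dict = {'features': torch.quantization.default_qconfig}
+ propagate_qconfig_(model, qconfig_dict)
+ # every module under `model.features` now carries this qconfig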
+ """
+ if qconfig_dict is None:
+ qconfig_dict = {}
+ _propagate_qconfig_helper(module, qconfig_dict, allow_list)
+
+def _observer_forward_hook(self, input, output):
+ r"""Forward hook that calls observer on the output
+ """
+ return self.activation_post_process(output)
+
+def register_activation_post_process_hook(module):
+ assert hasattr(module, 'activation_post_process'), \
+ 'Expected the activation_post_process attribute to already be attached to the module'
+ return module.register_forward_hook(_observer_forward_hook)
+
+def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
+ r"""Adds observers for the leaf children of the module.
+
+ This function inserts an observer module into every leaf child module
+ that has a valid qconfig attribute.
+
+ Args:
+ module: input module with qconfig attributes for all the leaf modules that we want to quantize
+ device: parent device, if any
+ non_leaf_module_list: list of non-leaf modules we want to add observer
+
+ Return:
+ None, module is modified inplace with added observer modules and forward_hooks
+ """
+ if qconfig_propagation_list is None:
+ qconfig_propagation_list = get_default_qconfig_propagation_list()
+
+ if custom_module_class_mapping is None:
+ custom_module_class_mapping = {}
+
+ # respect device affinity when adding observers
+ if device is None:
+ devices = get_unique_devices_(module)
+ assert len(devices) <= 1, (
+ "add_observer_ only works with cpu or single-device CUDA modules, "
+ "but got devices {}".format(devices)
+ )
+ device = next(iter(devices)) if len(devices) > 0 else None
+
+ def get_activation_post_process(qconfig, device, special_act_post_process=None):
+ activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
+ if device is not None:
+ activation.to(device)
+ return activation
+
+ def needs_observation(m):
+ return hasattr(m, 'qconfig') and m.qconfig is not None
+
+ def insert_activation_post_process(m, special_act_post_process=None):
+ """Adds an activation post process module and registers
+ a forward hook that calls the module.
+ """
+ # We don't insert observer/fake_quantize for DeQuantStub
+ if needs_observation(m) and not isinstance(m, DeQuantStub):
+ # observer and hook will be gone after we swap the module
+ m.add_module('activation_post_process', get_activation_post_process(
+ m.qconfig, device, special_act_post_process))
+ # Register observer as the first entry in the hook list
+ # All post forward hooks are preserved and will be executed after the observer before convert
+ handle = register_activation_post_process_hook(m)
+ m._forward_hooks.move_to_end(handle.id, last=False)
+
+ for name, child in module.named_children():
+ if type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
+ if needs_observation(child):
+ child.activation_post_process = get_activation_post_process(child.qconfig, device)
+ elif isinstance(child, _FusedModule):
+ # activation_post_process is now added directly to nn.Sequential/_FusedModule
+ if needs_observation(child):
+ insert_activation_post_process(child)
+ elif _has_special_act_post_process(child):
+ special_act_post_process = _get_special_act_post_process(child)
+ insert_activation_post_process(child, special_act_post_process)
+ elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
+ if needs_observation(child):
+ insert_activation_post_process(child)
+ elif needs_observation(child) and type(child) in custom_module_class_mapping:
+ observed_child = custom_module_class_mapping[type(child)].from_float(child)
+ setattr(module, name, observed_child)
+ # TODO: These are the modules that cannot be observed
+ # Once there are more, we should move them to a separate list
+ if custom_module_class_mapping[type(child)] not in no_observer_set():
+ insert_activation_post_process(observed_child)
+ else:
+ add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
+
+ # Insert observers only for leaf nodes. Note that this observer is for the
+ # output of the module; the inputs are observed by QuantStub
+ if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
+ and type(module) in qconfig_propagation_list:
+ insert_activation_post_process(module)
+
+def get_unique_devices_(module):
+ return {p.device for p in module.parameters()} | \
+ {p.device for p in module.buffers()}
+
+def add_quant_dequant(module):
+ r"""Wraps the leaf child modules in QuantWrapper if they have a valid qconfig.
+ Note that this function modifies the children of the module in place and it
+ can also return a new module which wraps the input module.
+
+ Args:
+ module: input module with qconfig attributes for all the leaf modules
+ that we want to quantize
+
+ Return:
+ Either the inplace modified module with submodules wrapped in
+ `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
+ wraps the input module, the latter case only happens when the input
+ module is a leaf module and we want to quantize it.
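+
+ A small illustrative sketch (the ``nn.Conv2d`` leaf is just an example module):
+
+ .. code-block:: python
+
+ conv = torch.nn.Conv2d(3, 3, 1)
+ conv.qconfig = torch.quantization.default_qconfig
+ wrapped = add_quant_dequant(conv)
+ # `wrapped` is a QuantWrapper with quant/dequant stubs around `conv`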
+ """
+ if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
+ return QuantWrapper(module)
+
+ for name, child in module.named_children():
+ module._modules[name] = add_quant_dequant(child)
+ return module
+
+def prepare(model, inplace=False, allow_list=None,
+ observer_non_leaf_module_list=None,
+ prepare_custom_config_dict=None):
+ r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
+
+ Quantization configuration should be assigned preemptively
+ to individual submodules in `.qconfig` attribute.
+
+ Observer or fake-quantize modules will be attached to the model, and the
+ qconfig will be propagated.
+
+ Args:
+ `model`: input model to be modified in-place
+ `inplace`: carry out model transformations in-place, the original module is mutated
+ `allow_list`: list of quantizable modules
+ `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
+ `prepare_custom_config_dict`: customization configuration dictionary for prepare function
+
+ .. code-block:: python
+
+ # Example of prepare_custom_config_dict:
+ prepare_custom_config_dict = {
+ # user will manually define the corresponding observed
+ # module class which has a from_float class method that converts
+ # float custom module to observed custom module
+ "float_to_observed_custom_module_class": {
+ CustomModule: ObservedCustomModule
+ }
+ }
+
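+ A minimal calibration sketch (``model`` and ``calibration_data`` are
+ illustrative placeholders; running representative data through the prepared
+ model lets the observers record activation statistics):
+
+ .. code-block:: python
+
+ model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+ prepared = prepare(model)
+ prepared(calibration_data)  # observers record the ranges of activations
+ quantized = convert(prepared)
+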
+ """
+ torch._C._log_api_usage_once("quantization_api.quantize.prepare")
+ if prepare_custom_config_dict is None:
+ prepare_custom_config_dict = {}
+ custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+
+ if not inplace:
+ model = copy.deepcopy(model)
+
+ # TODO: remove allow_list
+ qconfig_propagation_list = allow_list
+ if qconfig_propagation_list is None:
+ qconfig_propagation_list = get_default_qconfig_propagation_list()
+ propagate_qconfig_(model, qconfig_dict=None)
+
+ # sanity check common API misusage
+ if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
+ warnings.warn("None of the submodules got a qconfig applied. Make sure you "
+ "passed a correct configuration through `qconfig_dict` or "
+ "assigned the `.qconfig` attribute directly on submodules")
+
+ add_observer_(
+ model, qconfig_propagation_list, observer_non_leaf_module_list,
+ custom_module_class_mapping=custom_module_class_mapping)
+ return model
+
+def _remove_activation_post_process(module):
+ # TODO: maybe we should change activation_post_process to _activation_post_process
+ # to prevent it from being used by user
+ if hasattr(module, 'activation_post_process') and \
+ is_activation_post_process(module.activation_post_process):
+ delattr(module, 'activation_post_process')
+
+ # remove the activation_post_process hook
+ handle_ids_to_remove = set()
+ for handle_id, hook_fn in module._forward_hooks.items():
+ if hook_fn is _observer_forward_hook:
+ handle_ids_to_remove.add(handle_id)
+ for handle_id in handle_ids_to_remove:
+ module._forward_hooks.pop(handle_id)
+
+# TODO: rename to something more general
+def _remove_qconfig(module):
+ r"""Clean up the qconfig left in the module so that new qconfig can be
+ propagated.
+
+ Args:
+ module: module to be cleaned up
+ """
+ for child in module.children():
+ _remove_qconfig(child)
+
+ if hasattr(module, "qconfig"):
+ del module.qconfig
+
+ _remove_activation_post_process(module)
+
+def quantize(model, run_fn, run_args, mapping=None, inplace=False):
+ r"""Quantize the input float model with post training static quantization.
+
+ First it prepares the model for calibration, then calls `run_fn` to run the
+ calibration step, and finally converts the model to a quantized model.
+
+ Args:
+ model: input float model
+ run_fn: a calibration function for calibrating the prepared model
+ run_args: positional arguments for `run_fn`
+ inplace: carry out model transformations in-place, the original module is mutated
+ mapping: correspondence between original module types and quantized counterparts
+
+ Return:
+ Quantized model.
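+
+ A minimal usage sketch (``MyModel``, ``calibrate_fn`` and ``data_loader`` are
+ illustrative placeholders; ``calibrate_fn(model, data)`` simply runs the model
+ over the calibration data):
+
+ .. code-block:: python
+
+ # MyModel is assumed to wrap its forward with QuantStub/DeQuantStub
+ model_fp32 = MyModel().eval()
+ model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
+ model_int8 = quantize(model_fp32, calibrate_fn, [data_loader])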
+ """
+ torch._C._log_api_usage_once("quantization_api.quantize.quantize")
+ if mapping is None:
+ mapping = get_default_static_quant_module_mappings()
+ if not inplace:
+ model = copy.deepcopy(model)
+ model.eval()
+ prepare(model, inplace=True)
+ run_fn(model, *run_args)
+ convert(model, mapping, inplace=True)
+ return model
+
+def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
+ mapping=None, inplace=False):
+ r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.
+
+ Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.
+
+ For the simplest usage, provide the `dtype` argument, which can be float16 or qint8. Weight-only
+ quantization is by default performed for layers with large weights - i.e. Linear and RNN variants.
+
+ Fine-grained control is possible with `qconfig_spec` and `mapping`, which act similarly to `quantize()`.
+ If `qconfig_spec` is provided, the `dtype` argument is ignored.
+
+ Args:
+ model: input model
+ qconfig_spec: Either:
+
+ - A dictionary that maps from name or type of submodule to quantization
+ configuration, qconfig applies to all submodules of a given
+ module unless qconfig for the submodules are specified (when the
+ submodule already has qconfig attribute). Entries in the dictionary
+ need to be QConfigDynamic instances.
+
+ - A set of types and/or submodule names to apply dynamic quantization to,
+ in which case the `dtype` argument is used to specify the bit-width
+
+ inplace: carry out model transformations in-place, the original module is mutated
+ mapping: maps type of a submodule to a type of corresponding dynamically quantized version
+ with which the submodule needs to be replaced
+
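+ A minimal usage sketch (``MyModel`` is an illustrative placeholder):
+
+ .. code-block:: python
+
+ model_fp32 = MyModel().eval()
+ # dynamically quantize only the Linear layers, with int8 weights
+ model_int8 = quantize_dynamic(model_fp32, {torch.nn.Linear}, dtype=torch.qint8)
+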
+ """
+ torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
+ if qconfig_spec is None:
+ if dtype == torch.qint8:
+ qconfig_spec = {
+ nn.Linear : default_dynamic_qconfig,
+ nn.LSTM : default_dynamic_qconfig,
+ nn.GRU : default_dynamic_qconfig,
+ nn.LSTMCell : default_dynamic_qconfig,
+ nn.RNNCell : default_dynamic_qconfig,
+ nn.GRUCell : default_dynamic_qconfig,
+ }
+ elif dtype == torch.float16:
+ qconfig_spec = {
+ nn.Linear : float16_dynamic_qconfig,
+ nn.LSTM : float16_dynamic_qconfig,
+ nn.GRU : float16_dynamic_qconfig,
+ nn.LSTMCell : float16_dynamic_qconfig,
+ nn.RNNCell : float16_dynamic_qconfig,
+ nn.GRUCell : float16_dynamic_qconfig,
+ }
+ elif dtype == torch.quint8:
+ qconfig_spec = {
+ nn.EmbeddingBag : float_qparams_weight_only_qconfig,
+ }
+ else:
+ raise ValueError(
+ "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
+ elif isinstance(qconfig_spec, set):
+ if dtype is torch.qint8:
+ default_qconfig = default_dynamic_qconfig
+ elif dtype is torch.float16:
+ default_qconfig = float16_dynamic_qconfig
+ elif dtype is torch.quint8:
+ default_qconfig = float_qparams_weight_only_qconfig
+ else:
+ raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
+ qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
+
+ if mapping is None:
+ mapping = get_default_dynamic_quant_module_mappings()
+
+ if not inplace:
+ model = copy.deepcopy(model)
+ model.eval()
+ propagate_qconfig_(model, qconfig_spec)
+ convert(model, mapping, inplace=True)
+ return model
+
+def prepare_qat(model, mapping=None, inplace=False):
+ r"""
+ Prepares a copy of the model for quantization calibration or
+ quantization-aware training and converts it to a quantized version.
+
+ Quantization configuration should be assigned preemptively
+ to individual submodules in `.qconfig` attribute.
+
+ Args:
+ model: input model to be modified in-place
+ mapping: dictionary that maps float modules to quantized modules to be
+ replaced.
+ inplace: carry out model transformations in-place, the original module
+ is mutated
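+
+ A minimal usage sketch (``MyModel`` is an illustrative placeholder and is
+ assumed to contain QuantStub/DeQuantStub around the region to quantize):
+
+ .. code-block:: python
+
+ model = MyModel().train()
+ model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
+ model_prepared = prepare_qat(model)
+ # ... run the training loop on model_prepared, then convert ...
+ model_int8 = convert(model_prepared.eval())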
+ """
+ torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
+ if mapping is None:
+ mapping = get_default_qat_module_mappings()
+
+ if not inplace:
+ model = copy.deepcopy(model)
+
+ propagate_qconfig_(model, qconfig_dict=None)
+ convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
+ prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
+ return model
+
+def quantize_qat(model, run_fn, run_args, inplace=False):
+ r"""Do quantization-aware training and output a quantized model
+
+ Args:
+ model: input model
+ run_fn: a function for evaluating the prepared model, can be a
+ function that simply runs the prepared model or a training
+ loop
+ run_args: positional arguments for `run_fn`
+
+ Return:
+ Quantized model.
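+
+ A minimal usage sketch (``MyModel``, ``train_fn`` and ``train_loader`` are
+ illustrative placeholders):
+
+ .. code-block:: python
+
+ model = MyModel()
+ model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
+ model_int8 = quantize_qat(model, train_fn, [train_loader])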
+ """
+ torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
+ if not inplace:
+ model = copy.deepcopy(model)
+ model.train()
+ prepare_qat(model, inplace=True)
+ run_fn(model, *run_args)
+ convert(model, inplace=True)
+ return model
+
+def convert(
+ module, mapping=None, inplace=False, remove_qconfig=True,
+ convert_custom_config_dict=None):
+ r"""Converts submodules in the input module to a different module according to `mapping`
+ by calling the `from_float` method on the target module class. The qconfig is removed at
+ the end if `remove_qconfig` is set to True.
+
+ Args:
+ `module`: prepared and calibrated module
+ `mapping`: a dictionary that maps from source module type to target
+ module type, can be overwritten to allow swapping user defined
+ Modules
+ `inplace`: carry out model transformations in-place, the original module
+ is mutated
+ `convert_custom_config_dict`: custom configuration dictionary for convert function
+
+ .. code-block:: python
+
+ # Example of convert_custom_config_dict:
+ convert_custom_config_dict = {
+ # user will manually define the corresponding quantized
+ # module class which has a from_observed class method that converts
+ # observed custom module to quantized custom module
+ "observed_to_quantized_custom_module_class": {
+ ObservedCustomModule: QuantizedCustomModule
+ }
+ }
+
+ """
+ torch._C._log_api_usage_once("quantization_api.quantize.convert")
+ if not inplace:
+ module = copy.deepcopy(module)
+ _convert(
+ module, mapping, inplace=True,
+ convert_custom_config_dict=convert_custom_config_dict)
+ if remove_qconfig:
+ _remove_qconfig(module)
+ return module
+
+def _convert(
+ module, mapping=None, inplace=False,
+ convert_custom_config_dict=None):
+ r"""Converts submodules in the input module to a different module according to `mapping`
+ by calling the `from_float` method on the target module class
+
+ Args:
+ module: input module
+ mapping: a dictionary that maps from source module type to target
+ module type, can be overwritten to allow swapping user defined
+ Modules
+ inplace: carry out model transformations in-place, the original module
+ is mutated
+
+ """
+ if mapping is None:
+ mapping = get_default_static_quant_module_mappings()
+ if convert_custom_config_dict is None:
+ convert_custom_config_dict = {}
+ custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
+
+ if not inplace:
+ module = copy.deepcopy(module)
+ reassign = {}
+ for name, mod in module.named_children():
+ # both fused modules and observed custom modules are
+ # swapped as one unit
+ if not isinstance(mod, _FusedModule) and \
+ type(mod) not in custom_module_class_mapping:
+ _convert(mod, mapping, True, # inplace
+ convert_custom_config_dict)
+ reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
+
+ for key, value in reassign.items():
+ module._modules[key] = value
+
+ return module
+
+def swap_module(mod, mapping, custom_module_class_mapping):
+ r"""Swaps the module if it has a quantized counterpart and it has an
+ `observer` attached.
+
+ Args:
+ mod: input module
+ mapping: a dictionary that maps from nn module to nnq module
+
+ Return:
+ The corresponding quantized module of `mod`
+ """
+ new_mod = mod
+ if hasattr(mod, 'qconfig') and mod.qconfig is not None:
+ swapped = False
+ if type(mod) in custom_module_class_mapping:
+ new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
+ swapped = True
+ elif type(mod) in mapping:
+ new_mod = mapping[type(mod)].from_float(mod)
+ swapped = True
+
+ if swapped:
+ # Preserve module's pre forward hooks. They'll be called on quantized input
+ for pre_hook_fn in mod._forward_pre_hooks.values():
+ new_mod.register_forward_pre_hook(pre_hook_fn)
+ # Preserve module's post forward hooks except _observer_forward_hook
+ # After convert they'll work with quantized output
+ for hook_fn in mod._forward_hooks.values():
+ if hook_fn is not _observer_forward_hook:
+ new_mod.register_forward_hook(hook_fn)
+
+ # respect device affinity when swapping modules
+ devices = get_unique_devices_(mod)
+ assert len(devices) <= 1, (
+ "swap_module only works with cpu or single-device CUDA modules, "
+ "but got devices {}".format(devices)
+ )
+ device = next(iter(devices)) if len(devices) > 0 else None
+ if device:
+ new_mod.to(device)
+ return new_mod
+
+def get_observer_dict(mod, target_dict, prefix=""):
+ r"""Traverse the modules and save all observers into a dict.
+ This is mainly used for quantization accuracy debugging.
+ Args:
+ mod: the top module we want to save all observers
+ prefix: the prefix for the current module
+ target_dict: the dictionary used to save all the observers
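+
+ A small illustrative sketch (``prepared_model`` is a model returned by ``prepare``):
+
+ .. code-block:: python
+
+ observer_dict = {}
+ get_observer_dict(prepared_model, observer_dict)
+ # observer_dict now maps fully qualified module paths to their observers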
+ """
+ def get_prefix(prefix):
+ return prefix if prefix == "" else prefix + '.'
+
+ if hasattr(mod, 'activation_post_process'):
+ target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
+ for name, child in mod.named_children():
+ module_prefix = get_prefix(prefix) + name if prefix else name
+ get_observer_dict(child, target_dict, module_prefix)
WEIGHT_INDEX_DICT,
)
-from ..quantize import (
+from torch.ao.quantization.quantize import (
_remove_qconfig,
is_activation_post_process,
)
get_default_qat_module_mappings,
)
-from ..quantize import (
+from torch.ao.quantization.quantize import (
is_activation_post_process,
convert
)
get_qparam_dict,
)
-from ..quantize import (
+from torch.ao.quantization.quantize import (
is_activation_post_process,
)
import torch
import torch.nn as nn
from ..utils import is_per_tensor, is_per_channel
-from ..quantize import is_activation_post_process
+from torch.ao.quantization.quantize import is_activation_post_process
from torch.fx import GraphModule, map_arg
-import copy
-import itertools
-import warnings
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to
+`torch/ao/quantization/quantize.py` and add a corresponding import statement
+here.
+"""
import torch
-import torch.nn as nn
-import torch.nn.quantized as nnq
-from torch.nn.intrinsic import _FusedModule
-from .quantization_mappings import (
- get_default_dynamic_quant_module_mappings,
- get_default_static_quant_module_mappings,
- get_default_qat_module_mappings,
- get_default_qconfig_propagation_list,
- no_observer_set,
- _has_special_act_post_process,
- _get_special_act_post_process,
-)
-
-from .stubs import DeQuantStub, QuantWrapper
-from .qconfig import (
- add_module_to_qconfig_obs_ctr,
- default_dynamic_qconfig,
- float16_dynamic_qconfig,
- float_qparams_weight_only_qconfig)
-
-def is_activation_post_process(module):
- return (isinstance(module, torch.quantization.ObserverBase) or
- isinstance(module, torch.quantization.FakeQuantizeBase))
-
-def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
- qconfig_parent=None, prefix=''):
- r"""This is a helper function for `propagate_qconfig_`
-
- Args:
- module: input module
- qconfig_dict: dictionary that maps from name of submodule to quantization
- configuration
- allow_list: list of quantizable modules
- qconfig_parent: quantization config of parent module, we will fallback to
- this config when there is no specified config for current
- module
- prefix: corresponding prefix of the current module, used as key in
- qconfig_dict
-
- Return:
- None, module is modified inplace with qconfig attached
- """
- # TODO: Add test
- if allow_list is None:
- allow_list = get_default_qconfig_propagation_list()
-
- module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
- module_qconfig = qconfig_dict.get(prefix, module_qconfig)
- module_qconfig = getattr(module, 'qconfig', module_qconfig)
-
- torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)
-
- qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
- module.qconfig = qconfig_with_device_check
-
- for name, child in module.named_children():
- module_prefix = prefix + '.' + name if prefix else name
- _propagate_qconfig_helper(child, qconfig_dict, allow_list,
- qconfig_with_device_check, module_prefix)
-
-# TODO(jerryzh): expose allow_list
-def propagate_qconfig_(module, qconfig_dict=None, allow_list=None):
- r"""Propagate qconfig through the module hierarchy and assign `qconfig`
- attribute on each leaf module
-
- Args:
- module: input module
- qconfig_dict: dictionary that maps from name or type of submodule to
- quantization configuration, qconfig applies to all submodules of a
- given module unless qconfig for the submodules are specified (when
- the submodule already has qconfig attribute)
-
- Return:
- None, module is modified inplace with qconfig attached
- """
- if qconfig_dict is None:
- qconfig_dict = {}
- _propagate_qconfig_helper(module, qconfig_dict, allow_list)
-
-def _observer_forward_hook(self, input, output):
- r"""Forward hook that calls observer on the output
- """
- return self.activation_post_process(output)
-
-def register_activation_post_process_hook(module):
- assert hasattr(module, 'activation_post_process'), \
- 'Expect activation_post_process attribut already attached to the module'
- return module.register_forward_hook(_observer_forward_hook)
-
-def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
- r"""Add observer for the leaf child of the module.
-
- This function insert observer module to all leaf child module that
- has a valid qconfig attribute.
-
- Args:
- module: input module with qconfig attributes for all the leaf modules that we want to quantize
- device: parent device, if any
- non_leaf_module_list: list of non-leaf modules we want to add observer
-
- Return:
- None, module is modified inplace with added observer modules and forward_hooks
- """
- if qconfig_propagation_list is None:
- qconfig_propagation_list = get_default_qconfig_propagation_list()
-
- if custom_module_class_mapping is None:
- custom_module_class_mapping = {}
-
- # respect device affinity when adding observers
- if device is None:
- devices = get_unique_devices_(module)
- assert len(devices) <= 1, (
- "add_observer_ only works with cpu or single-device CUDA modules, "
- "but got devices {}".format(devices)
- )
- device = next(iter(devices)) if len(devices) > 0 else None
-
- def get_activation_post_process(qconfig, device, special_act_post_process=None):
- activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
- if device is not None:
- activation.to(device)
- return activation
-
- def needs_observation(m):
- return hasattr(m, 'qconfig') and m.qconfig is not None
-
- def insert_activation_post_process(m, special_act_post_process=None):
- """ Adds an activation post process module and register
- a post hook that calls the module
- """
- # We don't insert observer/fake_quantize for DeQuantStub
- if needs_observation(m) and not isinstance(m, DeQuantStub):
- # observer and hook will be gone after we swap the module
- m.add_module('activation_post_process', get_activation_post_process(
- m.qconfig, device, special_act_post_process))
- # Register observer as the first entry in the hook list
- # All post forward hooks are preserved and will be executed after the observer before convert
- handle = register_activation_post_process_hook(m)
- m._forward_hooks.move_to_end(handle.id, last=False)
-
- for name, child in module.named_children():
- if type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
- if needs_observation(child):
- child.activation_post_process = get_activation_post_process(child.qconfig, device)
- elif isinstance(child, _FusedModule):
- # activation_post_process are now added directly to nn.Sequentail/_FusedModule
- if needs_observation(child):
- insert_activation_post_process(child)
- elif _has_special_act_post_process(child):
- special_act_post_process = _get_special_act_post_process(child)
- insert_activation_post_process(child, special_act_post_process)
- elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
- if needs_observation(child):
- insert_activation_post_process(child)
- elif needs_observation(child) and type(child) in custom_module_class_mapping:
- observed_child = custom_module_class_mapping[type(child)].from_float(child)
- setattr(module, name, observed_child)
- # TODO: These are the modules that cannot be observed
- # Once there are more, we should move them to a separate list
- if custom_module_class_mapping[type(child)] not in no_observer_set():
- insert_activation_post_process(observed_child)
- else:
- add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
-
- # Insert observers only for leaf nodes, note that this observer is for
- # the output of the module, for input QuantStub will observe them
- if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
- and type(module) in qconfig_propagation_list:
- insert_activation_post_process(module)
-
-def get_unique_devices_(module):
- return {p.device for p in module.parameters()} | \
- {p.device for p in module.buffers()}
-
-def add_quant_dequant(module):
- r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
- Note that this function will modify the children of module inplace and it
- can return a new module which wraps the input module as well.
-
- Args:
- module: input module with qconfig attributes for all the leaf modules
- that we want to quantize
-
- Return:
- Either the inplace modified module with submodules wrapped in
- `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
- wraps the input module, the latter case only happens when the input
- module is a leaf module and we want to quantize it.
- """
- if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
- return QuantWrapper(module)
-
- for name, child in module.named_children():
- module._modules[name] = add_quant_dequant(child)
- return module
-
-def prepare(model, inplace=False, allow_list=None,
- observer_non_leaf_module_list=None,
- prepare_custom_config_dict=None):
- r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
-
- Quantization configuration should be assigned preemptively
- to individual submodules in `.qconfig` attribute.
-
- The model will be attached with observer or fake quant modules, and qconfig
- will be propagated.
-
- Args:
- `model`: input model to be modified in-place
- `inplace`: carry out model transformations in-place, the original module is mutated
- `allow_list`: list of quantizable modules
- `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
- `prepare_custom_config_dict`: customization configuration dictionary for prepare function
-
- .. code-block:: python
-
- # Example of prepare_custom_config_dict:
- prepare_custom_config_dict = {
- # user will manually define the corresponding observed
- # module class which has a from_float class method that converts
- # float custom module to observed custom module
- "float_to_observed_custom_module_class": {
- CustomModule: ObservedCustomModule
- }
- }
-
- """
- torch._C._log_api_usage_once("quantization_api.quantize.prepare")
- if prepare_custom_config_dict is None:
- prepare_custom_config_dict = {}
- custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
-
- if not inplace:
- model = copy.deepcopy(model)
-
- # TODO: remove allow_list
- qconfig_propagation_list = allow_list
- if qconfig_propagation_list is None:
- qconfig_propagation_list = get_default_qconfig_propagation_list()
- propagate_qconfig_(model, qconfig_dict=None)
-
- # sanity check common API misusage
- if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
- warnings.warn("None of the submodule got qconfig applied. Make sure you "
- "passed correct configuration through `qconfig_dict` or "
- "by assigning the `.qconfig` attribute directly on submodules")
-
- add_observer_(
- model, qconfig_propagation_list, observer_non_leaf_module_list,
- custom_module_class_mapping=custom_module_class_mapping)
- return model
-
-def _remove_activation_post_process(module):
- # TODO: maybe we should change activation_post_process to _activation_post_process
- # to prevent it from being used by user
- if hasattr(module, 'activation_post_process') and \
- is_activation_post_process(module.activation_post_process):
- delattr(module, 'activation_post_process')
-
- # remove activation_post_proceess hook
- handle_ids_to_remove = set()
- for handle_id, hook_fn in module._forward_hooks.items():
- if hook_fn is _observer_forward_hook:
- handle_ids_to_remove.add(handle_id)
- for handle_id in handle_ids_to_remove:
- module._forward_hooks.pop(handle_id)
-
-# TODO: rename to something more general
-def _remove_qconfig(module):
- r"""Clean up the qconfig left in the module so that new qconfig can be
- propagated.
-
- Args:
- module: module to be cleaned up
- """
- for child in module.children():
- _remove_qconfig(child)
-
- if hasattr(module, "qconfig"):
- del module.qconfig
-
- _remove_activation_post_process(module)
-
-def quantize(model, run_fn, run_args, mapping=None, inplace=False):
- r"""Quantize the input float model with post training static quantization.
-
- First it will prepare the model for calibration, then it calls
- `run_fn` which will run the calibration step, after that we will
- convert the model to a quantized model.
-
- Args:
- model: input float model
- run_fn: a calibration function for calibrating the prepared model
- run_args: positional arguments for `run_fn`
- inplace: carry out model transformations in-place, the original module is mutated
- mapping: correspondence between original module types and quantized counterparts
-
- Return:
- Quantized model.
- """
- torch._C._log_api_usage_once("quantization_api.quantize.quantize")
- if mapping is None:
- mapping = get_default_static_quant_module_mappings()
- if not inplace:
- model = copy.deepcopy(model)
- model.eval()
- prepare(model, inplace=True)
- run_fn(model, *run_args)
- convert(model, mapping, inplace=True)
- return model
-
-def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
- mapping=None, inplace=False):
- r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
-
- Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
-
- For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
- by default is performed for layers with large weights size - i.e. Linear and RNN variants.
-
- Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
- If `qconfig` is provided, the `dtype` argument is ignored.
-
- Args:
- model: input model
- qconfig_spec: Either:
-
- - A dictionary that maps from name or type of submodule to quantization
- configuration, qconfig applies to all submodules of a given
- module unless qconfig for the submodules are specified (when the
- submodule already has qconfig attribute). Entries in the dictionary
- need to be QConfigDynamic instances.
-
- - A set of types and/or submodule names to apply dynamic quantization to,
- in which case the `dtype` argument is used to specify the bit-width
-
- inplace: carry out model transformations in-place, the original module is mutated
- mapping: maps type of a submodule to a type of corresponding dynamically quantized version
- with which the submodule needs to be replaced
-
- """
- torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
- if qconfig_spec is None:
- if dtype == torch.qint8:
- qconfig_spec = {
- nn.Linear : default_dynamic_qconfig,
- nn.LSTM : default_dynamic_qconfig,
- nn.GRU : default_dynamic_qconfig,
- nn.LSTMCell : default_dynamic_qconfig,
- nn.RNNCell : default_dynamic_qconfig,
- nn.GRUCell : default_dynamic_qconfig,
- }
- elif dtype == torch.float16:
- qconfig_spec = {
- nn.Linear : float16_dynamic_qconfig,
- nn.LSTM : float16_dynamic_qconfig,
- nn.GRU : float16_dynamic_qconfig,
- nn.LSTMCell : float16_dynamic_qconfig,
- nn.RNNCell : float16_dynamic_qconfig,
- nn.GRUCell : float16_dynamic_qconfig,
- }
- elif dtype == torch.quint8:
- qconfig_spec = {
- nn.EmbeddingBag : float_qparams_weight_only_qconfig,
- }
- else:
- raise ValueError(
- "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
- elif isinstance(qconfig_spec, set):
- if dtype is torch.qint8:
- default_qconfig = default_dynamic_qconfig
- elif dtype is torch.float16:
- default_qconfig = float16_dynamic_qconfig
- elif dtype is torch.quint8:
- default_qconfig = float_qparams_weight_only_qconfig
- else:
- raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
- qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
-
- if mapping is None:
- mapping = get_default_dynamic_quant_module_mappings()
-
- if not inplace:
- model = copy.deepcopy(model)
- model.eval()
- propagate_qconfig_(model, qconfig_spec)
- convert(model, mapping, inplace=True)
- return model
-
-def prepare_qat(model, mapping=None, inplace=False):
- r"""
- Prepares a copy of the model for quantization calibration or
- quantization-aware training and converts it to quantized version.
-
- Quantization configuration should be assigned preemptively
- to individual submodules in `.qconfig` attribute.
-
- Args:
- model: input model to be modified in-place
- mapping: dictionary that maps float modules to quantized modules to be
- replaced.
- inplace: carry out model transformations in-place, the original module
- is mutated
- """
- torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
- if mapping is None:
- mapping = get_default_qat_module_mappings()
-
- if not inplace:
- model = copy.deepcopy(model)
-
- propagate_qconfig_(model, qconfig_dict=None)
- convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
- prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
- return model
-
-def quantize_qat(model, run_fn, run_args, inplace=False):
- r"""Do quantization aware training and output a quantized model
-
- Args:
- model: input model
- run_fn: a function for evaluating the prepared model, can be a
- function that simply runs the prepared model or a training
- loop
- run_args: positional arguments for `run_fn`
-
- Return:
- Quantized model.
- """
- torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
- if not inplace:
- model = copy.deepcopy(model)
- model.train()
- prepare_qat(model, inplace=True)
- run_fn(model, *run_args)
- convert(model, inplace=True)
- return model
-
-def convert(
- module, mapping=None, inplace=False, remove_qconfig=True,
- convert_custom_config_dict=None):
- r"""Converts submodules in input module to a different module according to `mapping`
- by calling `from_float` method on the target module class. And remove qconfig at the
- end if remove_qconfig is set to True.
-
- Args:
- `module`: prepared and calibrated module
- `mapping`: a dictionary that maps from source module type to target
- module type, can be overwritten to allow swapping user defined
- Modules
- `inplace`: carry out model transformations in-place, the original module
- is mutated
- `convert_custom_config_dict`: custom configuration dictionary for convert function
-
- .. code-block:: python
-
- # Example of convert_custom_config_dict:
- convert_custom_config_dict = {
- # user will manually define the corresponding quantized
- # module class which has a from_observed class method that converts
- # observed custom module to quantized custom module
- "observed_to_quantized_custom_module_class": {
- ObservedCustomModule: QuantizedCustomModule
- }
- }
-
- """
- torch._C._log_api_usage_once("quantization_api.quantize.convert")
- if not inplace:
- module = copy.deepcopy(module)
- _convert(
- module, mapping, inplace=True,
- convert_custom_config_dict=convert_custom_config_dict)
- if remove_qconfig:
- _remove_qconfig(module)
- return module
-
-def _convert(
- module, mapping=None, inplace=False,
- convert_custom_config_dict=None):
- r"""Converts submodules in input module to a different module according to `mapping`
- by calling `from_float` method on the target module class
-
- Args:
- module: input module
- mapping: a dictionary that maps from source module type to target
- module type, can be overwritten to allow swapping user defined
- Modules
- inplace: carry out model transformations in-place, the original module
- is mutated
-
- """
- if mapping is None:
- mapping = get_default_static_quant_module_mappings()
- if convert_custom_config_dict is None:
- convert_custom_config_dict = {}
- custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
-
- if not inplace:
- module = copy.deepcopy(module)
- reassign = {}
- for name, mod in module.named_children():
- # both fused modules and observed custom modules are
- # swapped as one unit
- if not isinstance(mod, _FusedModule) and \
- type(mod) not in custom_module_class_mapping:
- _convert(mod, mapping, True, # inplace
- convert_custom_config_dict)
- reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
-
- for key, value in reassign.items():
- module._modules[key] = value
-
- return module
-
-def swap_module(mod, mapping, custom_module_class_mapping):
- r"""Swaps the module if it has a quantized counterpart and it has an
- `observer` attached.
-
- Args:
- mod: input module
- mapping: a dictionary that maps from nn module to nnq module
-
- Return:
- The corresponding quantized module of `mod`
- """
- new_mod = mod
- if hasattr(mod, 'qconfig') and mod.qconfig is not None:
- swapped = False
- if type(mod) in custom_module_class_mapping:
- new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
- swapped = True
- elif type(mod) in mapping:
- new_mod = mapping[type(mod)].from_float(mod)
- swapped = True
-
- if swapped:
- # Preserve module's pre forward hooks. They'll be called on quantized input
- for pre_hook_fn in mod._forward_pre_hooks.values():
- new_mod.register_forward_pre_hook(pre_hook_fn)
- # Preserve module's post forward hooks except _observer_forward_hook
- # After convert they'll work with quantized output
- for hook_fn in mod._forward_hooks.values():
- if hook_fn is not _observer_forward_hook:
- new_mod.register_forward_hook(hook_fn)
-
- # respect device affinity when swapping modules
- devices = get_unique_devices_(mod)
- assert len(devices) <= 1, (
- "swap_module only works with cpu or single-device CUDA modules, "
- "but got devices {}".format(devices)
- )
- device = next(iter(devices)) if len(devices) > 0 else None
- if device:
- new_mod.to(device)
- return new_mod
-
-def get_observer_dict(mod, target_dict, prefix=""):
- r"""Traverse the modules and save all observers into dict.
- This is mainly used for quantization accuracy debug
- Args:
- mod: the top module we want to save all observers
- prefix: the prefix for the current module
- target_dict: the dictionary used to save all the observers
- """
- def get_prefix(prefix):
- return prefix if prefix == "" else prefix + '.'
-
- if hasattr(mod, 'activation_post_process'):
- target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
- for name, child in mod.named_children():
- module_prefix = get_prefix(prefix) + name if prefix else name
- get_observer_dict(child, target_dict, module_prefix)
+from torch.ao.quantization.quantize import _convert
+from torch.ao.quantization.quantize import _observer_forward_hook
+from torch.ao.quantization.quantize import _propagate_qconfig_helper
+from torch.ao.quantization.quantize import _remove_activation_post_process
+from torch.ao.quantization.quantize import _remove_qconfig
+from torch.ao.quantization.quantize import add_observer_
+from torch.ao.quantization.quantize import add_quant_dequant
+from torch.ao.quantization.quantize import convert
+from torch.ao.quantization.quantize import get_observer_dict
+from torch.ao.quantization.quantize import get_unique_devices_
+from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.quantize import prepare
+from torch.ao.quantization.quantize import prepare_qat
+from torch.ao.quantization.quantize import propagate_qconfig_
+from torch.ao.quantization.quantize import quantize
+from torch.ao.quantization.quantize import quantize_dynamic
+from torch.ao.quantization.quantize import quantize_qat
+from torch.ao.quantization.quantize import register_activation_post_process_hook
+from torch.ao.quantization.quantize import swap_module