[quant] AO migration of the `quantize.py` (resubmission) (#64445)
authorZafar Takhirov <zaf@fb.com>
Wed, 8 Sep 2021 11:57:28 +0000 (04:57 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 11:58:47 +0000 (04:58 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64445

AO Team is migrating the existing torch.quantization into torch.ao.quantization. We are doing it one file at a time to make sure that the internal callsites are updated properly.
This migrates the quantize.py from torch.quantization to torch.ao.quantization.
At this point both locations will be supported. Eventually the torch.quantization will be deprecated.

Test Plan: `buck test mode/dev //caffe2/test:quantization`

Reviewed By: HDCharles

Differential Revision: D30734870

fbshipit-source-id: dc204f3cc46bff2cc81c95159eab9d333b43bb4b

test/quantization/ao_migration/__init__.py [new file with mode: 0644]
test/quantization/ao_migration/test_quantize.py [new file with mode: 0644]
test/test_quantization.py
torch/ao/quantization/__init__.py [new file with mode: 0644]
torch/ao/quantization/quantize.py [new file with mode: 0644]
torch/quantization/fx/convert.py
torch/quantization/fx/prepare.py
torch/quantization/fx/quantization_patterns.py
torch/quantization/fx/utils.py
torch/quantization/quantize.py

diff --git a/test/quantization/ao_migration/__init__.py b/test/quantization/ao_migration/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/test/quantization/ao_migration/test_quantize.py b/test/quantization/ao_migration/test_quantize.py
new file mode 100644 (file)
index 0000000..086364b
--- /dev/null
@@ -0,0 +1,63 @@
+from torch.testing._internal.common_utils import TestCase
+
+import importlib
+from typing import List
+
+
+class AOMigrationTestCase(TestCase):
+    def _test_package_import(self, package_name: str):
+        r"""Tests the module import by making sure that all the internals match
+        (except the dunder methods)."""
+        old_module = importlib.import_module(f'torch.quantization.{package_name}')
+        new_module = importlib.import_module(f'torch.ao.quantization.{package_name}')
+        old_module_dir = set(dir(old_module))
+        new_module_dir = set(dir(new_module))
+        # Remove magic modules from checking in subsets
+        for el in list(old_module_dir):
+            if el[:2] == '__' and el[-2:] == '__':
+                old_module_dir.remove(el)
+        assert (old_module_dir <= new_module_dir), \
+            f"Importing {old_module} vs. {new_module} does not match: " \
+            f"{old_module_dir - new_module_dir}"
+
+    def _test_function_import(self, package_name: str, function_list: List[str]):
+        r"""Tests individual function list import by comparing the functions
+        and their hashes."""
+        old_location = importlib.import_module(f'torch.quantization.{package_name}')
+        new_location = importlib.import_module(f'torch.ao.quantization.{package_name}')
+        for fn_name in function_list:
+            old_function = getattr(old_location, fn_name)
+            new_function = getattr(new_location, fn_name)
+            assert old_function == new_function, f"Functions don't match: {fn_name}"
+            assert hash(old_function) == hash(new_function), \
+                f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
+                f"{new_function}({hash(new_function)})"
+
+
+class TestAOMigrationQuantizePy(AOMigrationTestCase):
+    def test_package_import(self):
+        self._test_package_import('quantize')
+
+    def test_function_import(self):
+        function_list = [
+            '_convert',
+            '_observer_forward_hook',
+            '_propagate_qconfig_helper',
+            '_remove_activation_post_process',
+            '_remove_qconfig',
+            'add_observer_',
+            'add_quant_dequant',
+            'convert',
+            'get_observer_dict',
+            'get_unique_devices_',
+            'is_activation_post_process',
+            'prepare',
+            'prepare_qat',
+            'propagate_qconfig_',
+            'quantize',
+            'quantize_dynamic',
+            'quantize_qat',
+            'register_activation_post_process_hook',
+            'swap_module',
+        ]
+        self._test_function_import('quantize', function_list)
index 8671513..7fd4e50 100644 (file)
@@ -100,6 +100,9 @@ from quantization.jit.test_quantize_jit import TestQuantizeDynamicJitOps  # noqa
 from quantization.jit.test_fusion_passes import TestFusionPasses  # noqa: F401
 from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized  # noqa: F401
 
+# AO Migration tests
+from quantization.ao_migration.test_quantize import TestAOMigrationQuantizePy  # noqa: F401
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py
new file mode 100644 (file)
index 0000000..54ff8fb
--- /dev/null
@@ -0,0 +1,580 @@
+import copy
+import itertools
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.quantized as nnq
+from torch.nn.intrinsic import _FusedModule
+
+from torch.quantization.quantization_mappings import (
+    get_default_dynamic_quant_module_mappings,
+    get_default_static_quant_module_mappings,
+    get_default_qat_module_mappings,
+    get_default_qconfig_propagation_list,
+    no_observer_set,
+    _has_special_act_post_process,
+    _get_special_act_post_process,
+)
+
+from torch.quantization.stubs import DeQuantStub, QuantWrapper
+from torch.quantization.qconfig import (
+    add_module_to_qconfig_obs_ctr,
+    default_dynamic_qconfig,
+    float16_dynamic_qconfig,
+    float_qparams_weight_only_qconfig)
+
+def is_activation_post_process(module):
+    return (isinstance(module, torch.quantization.ObserverBase) or
+            isinstance(module, torch.quantization.FakeQuantizeBase))
+
+def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
+                              qconfig_parent=None, prefix=''):
+    r"""This is a helper function for `propagate_qconfig_`
+
+    Args:
+        module: input module
+        qconfig_dict: dictionary that maps from name of submodule to quantization
+                     configuration
+        allow_list: list of quantizable modules
+        qconfig_parent: quantization config of parent module, we will fallback to
+                       this config when there is no specified config for current
+                       module
+        prefix: corresponding prefix of the current module, used as key in
+                qconfig_dict
+
+    Return:
+        None, module is modified inplace with qconfig attached
+    """
+    # TODO: Add test
+    if allow_list is None:
+        allow_list = get_default_qconfig_propagation_list()
+
+    module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
+    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
+    module_qconfig = getattr(module, 'qconfig', module_qconfig)
+
+    torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)
+
+    qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
+    module.qconfig = qconfig_with_device_check
+
+    for name, child in module.named_children():
+        module_prefix = prefix + '.' + name if prefix else name
+        _propagate_qconfig_helper(child, qconfig_dict, allow_list,
+                                  qconfig_with_device_check, module_prefix)
+
+# TODO(jerryzh): expose allow_list
+def propagate_qconfig_(module, qconfig_dict=None, allow_list=None):
+    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
+    attribute on each leaf module
+
+    Args:
+        module: input module
+        qconfig_dict: dictionary that maps from name or type of submodule to
+            quantization configuration, qconfig applies to all submodules of a
+            given module unless qconfig for the submodules are specified (when
+            the submodule already has qconfig attribute)
+
+    Return:
+        None, module is modified inplace with qconfig attached
+    """
+    if qconfig_dict is None:
+        qconfig_dict = {}
+    _propagate_qconfig_helper(module, qconfig_dict, allow_list)
+
+def _observer_forward_hook(self, input, output):
+    r"""Forward hook that calls observer on the output
+    """
+    return self.activation_post_process(output)
+
+def register_activation_post_process_hook(module):
+    assert hasattr(module, 'activation_post_process'), \
+        'Expect activation_post_process attribut already attached to the module'
+    return module.register_forward_hook(_observer_forward_hook)
+
+def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
+    r"""Add observer for the leaf child of the module.
+
+    This function insert observer module to all leaf child module that
+    has a valid qconfig attribute.
+
+    Args:
+        module: input module with qconfig attributes for all the leaf modules that we want to quantize
+        device: parent device, if any
+        non_leaf_module_list: list of non-leaf modules we want to add observer
+
+    Return:
+        None, module is modified inplace with added observer modules and forward_hooks
+    """
+    if qconfig_propagation_list is None:
+        qconfig_propagation_list = get_default_qconfig_propagation_list()
+
+    if custom_module_class_mapping is None:
+        custom_module_class_mapping = {}
+
+    # respect device affinity when adding observers
+    if device is None:
+        devices = get_unique_devices_(module)
+        assert len(devices) <= 1, (
+            "add_observer_ only works with cpu or single-device CUDA modules, "
+            "but got devices {}".format(devices)
+        )
+        device = next(iter(devices)) if len(devices) > 0 else None
+
+    def get_activation_post_process(qconfig, device, special_act_post_process=None):
+        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
+        if device is not None:
+            activation.to(device)
+        return activation
+
+    def needs_observation(m):
+        return hasattr(m, 'qconfig') and m.qconfig is not None
+
+    def insert_activation_post_process(m, special_act_post_process=None):
+        """ Adds an activation post process module and register
+        a post hook that calls the module
+        """
+        # We don't insert observer/fake_quantize for DeQuantStub
+        if needs_observation(m) and not isinstance(m, DeQuantStub):
+            # observer and hook will be gone after we swap the module
+            m.add_module('activation_post_process', get_activation_post_process(
+                m.qconfig, device, special_act_post_process))
+            # Register observer as the first entry in the hook list
+            # All post forward hooks are preserved and will be executed after the observer before convert
+            handle = register_activation_post_process_hook(m)
+            m._forward_hooks.move_to_end(handle.id, last=False)
+
+    for name, child in module.named_children():
+        if type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
+            if needs_observation(child):
+                child.activation_post_process = get_activation_post_process(child.qconfig, device)
+        elif isinstance(child, _FusedModule):
+            # activation_post_process are now added directly to nn.Sequentail/_FusedModule
+            if needs_observation(child):
+                insert_activation_post_process(child)
+        elif _has_special_act_post_process(child):
+            special_act_post_process = _get_special_act_post_process(child)
+            insert_activation_post_process(child, special_act_post_process)
+        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
+            if needs_observation(child):
+                insert_activation_post_process(child)
+        elif needs_observation(child) and type(child) in custom_module_class_mapping:
+            observed_child = custom_module_class_mapping[type(child)].from_float(child)
+            setattr(module, name, observed_child)
+            # TODO: These are the modules that cannot be observed
+            #       Once there are more, we should move them to a separate list
+            if custom_module_class_mapping[type(child)] not in no_observer_set():
+                insert_activation_post_process(observed_child)
+        else:
+            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
+
+    # Insert observers only for leaf nodes, note that this observer is for
+    # the output of the module, for input QuantStub will observe them
+    if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
+       and type(module) in qconfig_propagation_list:
+        insert_activation_post_process(module)
+
+def get_unique_devices_(module):
+    return {p.device for p in module.parameters()} | \
+        {p.device for p in module.buffers()}
+
+def add_quant_dequant(module):
+    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
+    Note that this function will modify the children of module inplace and it
+    can return a new module which wraps the input module as well.
+
+    Args:
+        module: input module with qconfig attributes for all the leaf modules
+        that we want to quantize
+
+    Return:
+        Either the inplace modified module with submodules wrapped in
+        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
+        wraps the input module, the latter case only happens when the input
+        module is a leaf module and we want to quantize it.
+    """
+    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
+        return QuantWrapper(module)
+
+    for name, child in module.named_children():
+        module._modules[name] = add_quant_dequant(child)
+    return module
+
+def prepare(model, inplace=False, allow_list=None,
+            observer_non_leaf_module_list=None,
+            prepare_custom_config_dict=None):
+    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
+
+    Quantization configuration should be assigned preemptively
+    to individual submodules in `.qconfig` attribute.
+
+    The model will be attached with observer or fake quant modules, and qconfig
+    will be propagated.
+
+    Args:
+        `model`: input model to be modified in-place
+        `inplace`: carry out model transformations in-place, the original module is mutated
+        `allow_list`: list of quantizable modules
+        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
+        `prepare_custom_config_dict`: customization configuration dictionary for prepare function
+
+    .. code-block:: python
+
+       # Example of prepare_custom_config_dict:
+       prepare_custom_config_dict = {
+           # user will manually define the corresponding observed
+           # module class which has a from_float class method that converts
+           # float custom module to observed custom module
+           "float_to_observed_custom_module_class": {
+               CustomModule: ObservedCustomModule
+           }
+        }
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
+    if prepare_custom_config_dict is None:
+        prepare_custom_config_dict = {}
+    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    # TODO: remove allow_list
+    qconfig_propagation_list = allow_list
+    if qconfig_propagation_list is None:
+        qconfig_propagation_list = get_default_qconfig_propagation_list()
+    propagate_qconfig_(model, qconfig_dict=None)
+
+    # sanity check common API misusage
+    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
+        warnings.warn("None of the submodule got qconfig applied. Make sure you "
+                      "passed correct configuration through `qconfig_dict` or "
+                      "by assigning the `.qconfig` attribute directly on submodules")
+
+    add_observer_(
+        model, qconfig_propagation_list, observer_non_leaf_module_list,
+        custom_module_class_mapping=custom_module_class_mapping)
+    return model
+
+def _remove_activation_post_process(module):
+    # TODO: maybe we should change activation_post_process to _activation_post_process
+    # to prevent it from being used by user
+    if hasattr(module, 'activation_post_process') and \
+       is_activation_post_process(module.activation_post_process):
+        delattr(module, 'activation_post_process')
+
+    # remove activation_post_proceess hook
+    handle_ids_to_remove = set()
+    for handle_id, hook_fn in module._forward_hooks.items():
+        if hook_fn is _observer_forward_hook:
+            handle_ids_to_remove.add(handle_id)
+    for handle_id in handle_ids_to_remove:
+        module._forward_hooks.pop(handle_id)
+
+# TODO: rename to something more general
+def _remove_qconfig(module):
+    r"""Clean up the qconfig left in the module so that new qconfig can be
+    propagated.
+
+    Args:
+        module: module to be cleaned up
+    """
+    for child in module.children():
+        _remove_qconfig(child)
+
+    if hasattr(module, "qconfig"):
+        del module.qconfig
+
+    _remove_activation_post_process(module)
+
+def quantize(model, run_fn, run_args, mapping=None, inplace=False):
+    r"""Quantize the input float model with post training static quantization.
+
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.
+
+    Args:
+        model: input float model
+        run_fn: a calibration function for calibrating the prepared model
+        run_args: positional arguments for `run_fn`
+        inplace: carry out model transformations in-place, the original module is mutated
+        mapping: correspondence between original module types and quantized counterparts
+
+    Return:
+        Quantized model.
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
+    if mapping is None:
+        mapping = get_default_static_quant_module_mappings()
+    if not inplace:
+        model = copy.deepcopy(model)
+    model.eval()
+    prepare(model, inplace=True)
+    run_fn(model, *run_args)
+    convert(model, mapping, inplace=True)
+    return model
+
+def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
+                     mapping=None, inplace=False):
+    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
+
+    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
+
+    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
+    by default is performed for layers with large weights size - i.e. Linear and RNN variants.
+
+    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
+    If `qconfig` is provided, the `dtype` argument is ignored.
+
+    Args:
+        model: input model
+        qconfig_spec: Either:
+
+            - A dictionary that maps from name or type of submodule to quantization
+              configuration, qconfig applies to all submodules of a given
+              module unless qconfig for the submodules are specified (when the
+              submodule already has qconfig attribute). Entries in the dictionary
+              need to be QConfigDynamic instances.
+
+            - A set of types and/or submodule names to apply dynamic quantization to,
+              in which case the `dtype` argument is used to specify the bit-width
+
+        inplace: carry out model transformations in-place, the original module is mutated
+        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
+            with which the submodule needs to be replaced
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
+    if qconfig_spec is None:
+        if dtype == torch.qint8:
+            qconfig_spec = {
+                nn.Linear : default_dynamic_qconfig,
+                nn.LSTM : default_dynamic_qconfig,
+                nn.GRU : default_dynamic_qconfig,
+                nn.LSTMCell : default_dynamic_qconfig,
+                nn.RNNCell : default_dynamic_qconfig,
+                nn.GRUCell : default_dynamic_qconfig,
+            }
+        elif dtype == torch.float16:
+            qconfig_spec = {
+                nn.Linear : float16_dynamic_qconfig,
+                nn.LSTM : float16_dynamic_qconfig,
+                nn.GRU : float16_dynamic_qconfig,
+                nn.LSTMCell : float16_dynamic_qconfig,
+                nn.RNNCell : float16_dynamic_qconfig,
+                nn.GRUCell : float16_dynamic_qconfig,
+            }
+        elif dtype == torch.quint8:
+            qconfig_spec = {
+                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
+            }
+        else:
+            raise ValueError(
+                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
+    elif isinstance(qconfig_spec, set):
+        if dtype is torch.qint8:
+            default_qconfig = default_dynamic_qconfig
+        elif dtype is torch.float16:
+            default_qconfig = float16_dynamic_qconfig
+        elif dtype is torch.quint8:
+            default_qconfig = float_qparams_weight_only_qconfig
+        else:
+            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
+        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
+
+    if mapping is None:
+        mapping = get_default_dynamic_quant_module_mappings()
+
+    if not inplace:
+        model = copy.deepcopy(model)
+    model.eval()
+    propagate_qconfig_(model, qconfig_spec)
+    convert(model, mapping, inplace=True)
+    return model
+
+def prepare_qat(model, mapping=None, inplace=False):
+    r"""
+    Prepares a copy of the model for quantization calibration or
+    quantization-aware training and converts it to quantized version.
+
+    Quantization configuration should be assigned preemptively
+    to individual submodules in `.qconfig` attribute.
+
+    Args:
+        model: input model to be modified in-place
+        mapping: dictionary that maps float modules to quantized modules to be
+                 replaced.
+        inplace: carry out model transformations in-place, the original module
+                 is mutated
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
+    if mapping is None:
+        mapping = get_default_qat_module_mappings()
+
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    propagate_qconfig_(model, qconfig_dict=None)
+    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
+    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
+    return model
+
+def quantize_qat(model, run_fn, run_args, inplace=False):
+    r"""Do quantization aware training and output a quantized model
+
+    Args:
+        model: input model
+        run_fn: a function for evaluating the prepared model, can be a
+                function that simply runs the prepared model or a training
+                loop
+        run_args: positional arguments for `run_fn`
+
+    Return:
+        Quantized model.
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
+    if not inplace:
+        model = copy.deepcopy(model)
+    model.train()
+    prepare_qat(model, inplace=True)
+    run_fn(model, *run_args)
+    convert(model, inplace=True)
+    return model
+
+def convert(
+        module, mapping=None, inplace=False, remove_qconfig=True,
+        convert_custom_config_dict=None):
+    r"""Converts submodules in input module to a different module according to `mapping`
+    by calling `from_float` method on the target module class. And remove qconfig at the
+    end if remove_qconfig is set to True.
+
+    Args:
+        `module`: prepared and calibrated module
+        `mapping`: a dictionary that maps from source module type to target
+                   module type, can be overwritten to allow swapping user defined
+                   Modules
+        `inplace`: carry out model transformations in-place, the original module
+                   is mutated
+        `convert_custom_config_dict`: custom configuration dictionary for convert function
+
+    .. code-block:: python
+
+       # Example of convert_custom_config_dict:
+       convert_custom_config_dict = {
+           # user will manually define the corresponding quantized
+           # module class which has a from_observed class method that converts
+           # observed custom module to quantized custom module
+           "observed_to_quantized_custom_module_class": {
+               ObservedCustomModule: QuantizedCustomModule
+           }
+       }
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.convert")
+    if not inplace:
+        module = copy.deepcopy(module)
+    _convert(
+        module, mapping, inplace=True,
+        convert_custom_config_dict=convert_custom_config_dict)
+    if remove_qconfig:
+        _remove_qconfig(module)
+    return module
+
+def _convert(
+        module, mapping=None, inplace=False,
+        convert_custom_config_dict=None):
+    r"""Converts submodules in input module to a different module according to `mapping`
+    by calling `from_float` method on the target module class
+
+    Args:
+        module: input module
+        mapping: a dictionary that maps from source module type to target
+                 module type, can be overwritten to allow swapping user defined
+                 Modules
+        inplace: carry out model transformations in-place, the original module
+                 is mutated
+
+    """
+    if mapping is None:
+        mapping = get_default_static_quant_module_mappings()
+    if convert_custom_config_dict is None:
+        convert_custom_config_dict = {}
+    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
+
+    if not inplace:
+        module = copy.deepcopy(module)
+    reassign = {}
+    for name, mod in module.named_children():
+        # both fused modules and observed custom modules are
+        # swapped as one unit
+        if not isinstance(mod, _FusedModule) and \
+           type(mod) not in custom_module_class_mapping:
+            _convert(mod, mapping, True,  # inplace
+                     convert_custom_config_dict)
+        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
+
+    for key, value in reassign.items():
+        module._modules[key] = value
+
+    return module
+
+def swap_module(mod, mapping, custom_module_class_mapping):
+    r"""Swaps the module if it has a quantized counterpart and it has an
+    `observer` attached.
+
+    Args:
+        mod: input module
+        mapping: a dictionary that maps from nn module to nnq module
+
+    Return:
+        The corresponding quantized module of `mod`
+    """
+    new_mod = mod
+    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
+        swapped = False
+        if type(mod) in custom_module_class_mapping:
+            new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
+            swapped = True
+        elif type(mod) in mapping:
+            new_mod = mapping[type(mod)].from_float(mod)
+            swapped = True
+
+        if swapped:
+            # Preserve module's pre forward hooks. They'll be called on quantized input
+            for pre_hook_fn in mod._forward_pre_hooks.values():
+                new_mod.register_forward_pre_hook(pre_hook_fn)
+            # Preserve module's post forward hooks except _observer_forward_hook
+            # After convert they'll work with quantized output
+            for hook_fn in mod._forward_hooks.values():
+                if hook_fn is not _observer_forward_hook:
+                    new_mod.register_forward_hook(hook_fn)
+
+            # respect device affinity when swapping modules
+            devices = get_unique_devices_(mod)
+            assert len(devices) <= 1, (
+                "swap_module only works with cpu or single-device CUDA modules, "
+                "but got devices {}".format(devices)
+            )
+            device = next(iter(devices)) if len(devices) > 0 else None
+            if device:
+                new_mod.to(device)
+    return new_mod
+
+def get_observer_dict(mod, target_dict, prefix=""):
+    r"""Traverse the modules and save all observers into dict.
+    This is mainly used for quantization accuracy debug
+    Args:
+        mod: the top module we want to save all observers
+        prefix: the prefix for the current module
+        target_dict: the dictionary used to save all the observers
+    """
+    def get_prefix(prefix):
+        return prefix if prefix == "" else prefix + '.'
+
+    if hasattr(mod, 'activation_post_process'):
+        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
+    for name, child in mod.named_children():
+        module_prefix = get_prefix(prefix) + name if prefix else name
+        get_observer_dict(child, target_dict, module_prefix)
index 867b0b2..e00e4aa 100644 (file)
@@ -36,7 +36,7 @@ from .utils import (
     WEIGHT_INDEX_DICT,
 )
 
-from ..quantize import (
+from torch.ao.quantization.quantize import (
     _remove_qconfig,
     is_activation_post_process,
 )
index d2bb96a..0b65e33 100644 (file)
@@ -67,7 +67,7 @@ from ..quantization_mappings import (
     get_default_qat_module_mappings,
 )
 
-from ..quantize import (
+from torch.ao.quantization.quantize import (
     is_activation_post_process,
     convert
 )
index 3f54a6a..418cae1 100644 (file)
@@ -24,7 +24,7 @@ from ..utils import (
     get_qparam_dict,
 )
 
-from ..quantize import (
+from torch.ao.quantization.quantize import (
     is_activation_post_process,
 )
 
index 10f8b06..3c9adc2 100644 (file)
@@ -2,7 +2,7 @@ import re
 import torch
 import torch.nn as nn
 from ..utils import is_per_tensor, is_per_channel
-from ..quantize import is_activation_post_process
+from torch.ao.quantization.quantize import is_activation_post_process
 
 from torch.fx import GraphModule, map_arg
 
index 674ed59..5b0f4ed 100644 (file)
-import copy
-import itertools
-import warnings
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/quantization/quantize.py`, while adding an import statement
+here.
+"""
 
 import torch
-import torch.nn as nn
-import torch.nn.quantized as nnq
-from torch.nn.intrinsic import _FusedModule
 
-from .quantization_mappings import (
-    get_default_dynamic_quant_module_mappings,
-    get_default_static_quant_module_mappings,
-    get_default_qat_module_mappings,
-    get_default_qconfig_propagation_list,
-    no_observer_set,
-    _has_special_act_post_process,
-    _get_special_act_post_process,
-)
-
-from .stubs import DeQuantStub, QuantWrapper
-from .qconfig import (
-    add_module_to_qconfig_obs_ctr,
-    default_dynamic_qconfig,
-    float16_dynamic_qconfig,
-    float_qparams_weight_only_qconfig)
-
-def is_activation_post_process(module):
-    return (isinstance(module, torch.quantization.ObserverBase) or
-            isinstance(module, torch.quantization.FakeQuantizeBase))
-
-def _propagate_qconfig_helper(module, qconfig_dict, allow_list=None,
-                              qconfig_parent=None, prefix=''):
-    r"""This is a helper function for `propagate_qconfig_`
-
-    Args:
-        module: input module
-        qconfig_dict: dictionary that maps from name of submodule to quantization
-                     configuration
-        allow_list: list of quantizable modules
-        qconfig_parent: quantization config of parent module, we will fallback to
-                       this config when there is no specified config for current
-                       module
-        prefix: corresponding prefix of the current module, used as key in
-                qconfig_dict
-
-    Return:
-        None, module is modified inplace with qconfig attached
-    """
-    # TODO: Add test
-    if allow_list is None:
-        allow_list = get_default_qconfig_propagation_list()
-
-    module_qconfig = qconfig_dict.get(type(module), qconfig_parent)
-    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
-    module_qconfig = getattr(module, 'qconfig', module_qconfig)
-
-    torch.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)
-
-    qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
-    module.qconfig = qconfig_with_device_check
-
-    for name, child in module.named_children():
-        module_prefix = prefix + '.' + name if prefix else name
-        _propagate_qconfig_helper(child, qconfig_dict, allow_list,
-                                  qconfig_with_device_check, module_prefix)
-
-# TODO(jerryzh): expose allow_list
-def propagate_qconfig_(module, qconfig_dict=None, allow_list=None):
-    r"""Propagate qconfig through the module hierarchy and assign `qconfig`
-    attribute on each leaf module
-
-    Args:
-        module: input module
-        qconfig_dict: dictionary that maps from name or type of submodule to
-            quantization configuration, qconfig applies to all submodules of a
-            given module unless qconfig for the submodules are specified (when
-            the submodule already has qconfig attribute)
-
-    Return:
-        None, module is modified inplace with qconfig attached
-    """
-    if qconfig_dict is None:
-        qconfig_dict = {}
-    _propagate_qconfig_helper(module, qconfig_dict, allow_list)
-
-def _observer_forward_hook(self, input, output):
-    r"""Forward hook that calls observer on the output
-    """
-    return self.activation_post_process(output)
-
-def register_activation_post_process_hook(module):
-    assert hasattr(module, 'activation_post_process'), \
-        'Expect activation_post_process attribut already attached to the module'
-    return module.register_forward_hook(_observer_forward_hook)
-
-def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
-    r"""Add observer for the leaf child of the module.
-
-    This function insert observer module to all leaf child module that
-    has a valid qconfig attribute.
-
-    Args:
-        module: input module with qconfig attributes for all the leaf modules that we want to quantize
-        device: parent device, if any
-        non_leaf_module_list: list of non-leaf modules we want to add observer
-
-    Return:
-        None, module is modified inplace with added observer modules and forward_hooks
-    """
-    if qconfig_propagation_list is None:
-        qconfig_propagation_list = get_default_qconfig_propagation_list()
-
-    if custom_module_class_mapping is None:
-        custom_module_class_mapping = {}
-
-    # respect device affinity when adding observers
-    if device is None:
-        devices = get_unique_devices_(module)
-        assert len(devices) <= 1, (
-            "add_observer_ only works with cpu or single-device CUDA modules, "
-            "but got devices {}".format(devices)
-        )
-        device = next(iter(devices)) if len(devices) > 0 else None
-
-    def get_activation_post_process(qconfig, device, special_act_post_process=None):
-        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
-        if device is not None:
-            activation.to(device)
-        return activation
-
-    def needs_observation(m):
-        return hasattr(m, 'qconfig') and m.qconfig is not None
-
-    def insert_activation_post_process(m, special_act_post_process=None):
-        """ Adds an activation post process module and register
-        a post hook that calls the module
-        """
-        # We don't insert observer/fake_quantize for DeQuantStub
-        if needs_observation(m) and not isinstance(m, DeQuantStub):
-            # observer and hook will be gone after we swap the module
-            m.add_module('activation_post_process', get_activation_post_process(
-                m.qconfig, device, special_act_post_process))
-            # Register observer as the first entry in the hook list
-            # All post forward hooks are preserved and will be executed after the observer before convert
-            handle = register_activation_post_process_hook(m)
-            m._forward_hooks.move_to_end(handle.id, last=False)
-
-    for name, child in module.named_children():
-        if type(child) in [nnq.FloatFunctional, nnq.QFunctional]:
-            if needs_observation(child):
-                child.activation_post_process = get_activation_post_process(child.qconfig, device)
-        elif isinstance(child, _FusedModule):
-            # activation_post_process are now added directly to nn.Sequentail/_FusedModule
-            if needs_observation(child):
-                insert_activation_post_process(child)
-        elif _has_special_act_post_process(child):
-            special_act_post_process = _get_special_act_post_process(child)
-            insert_activation_post_process(child, special_act_post_process)
-        elif non_leaf_module_list is not None and type(child) in non_leaf_module_list:
-            if needs_observation(child):
-                insert_activation_post_process(child)
-        elif needs_observation(child) and type(child) in custom_module_class_mapping:
-            observed_child = custom_module_class_mapping[type(child)].from_float(child)
-            setattr(module, name, observed_child)
-            # TODO: These are the modules that cannot be observed
-            #       Once there are more, we should move them to a separate list
-            if custom_module_class_mapping[type(child)] not in no_observer_set():
-                insert_activation_post_process(observed_child)
-        else:
-            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
-
-    # Insert observers only for leaf nodes, note that this observer is for
-    # the output of the module, for input QuantStub will observe them
-    if len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential) \
-       and type(module) in qconfig_propagation_list:
-        insert_activation_post_process(module)
-
-def get_unique_devices_(module):
-    return {p.device for p in module.parameters()} | \
-        {p.device for p in module.buffers()}
-
-def add_quant_dequant(module):
-    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
-    Note that this function will modify the children of module inplace and it
-    can return a new module which wraps the input module as well.
-
-    Args:
-        module: input module with qconfig attributes for all the leaf modules
-        that we want to quantize
-
-    Return:
-        Either the inplace modified module with submodules wrapped in
-        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
-        wraps the input module, the latter case only happens when the input
-        module is a leaf module and we want to quantize it.
-    """
-    if len(module._modules) == 0 and hasattr(module, 'qconfig') and module.qconfig:
-        return QuantWrapper(module)
-
-    for name, child in module.named_children():
-        module._modules[name] = add_quant_dequant(child)
-    return module
-
-def prepare(model, inplace=False, allow_list=None,
-            observer_non_leaf_module_list=None,
-            prepare_custom_config_dict=None):
-    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
-
-    Quantization configuration should be assigned preemptively
-    to individual submodules in `.qconfig` attribute.
-
-    The model will be attached with observer or fake quant modules, and qconfig
-    will be propagated.
-
-    Args:
-        `model`: input model to be modified in-place
-        `inplace`: carry out model transformations in-place, the original module is mutated
-        `allow_list`: list of quantizable modules
-        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
-        `prepare_custom_config_dict`: customization configuration dictionary for prepare function
-
-    .. code-block:: python
-
-       # Example of prepare_custom_config_dict:
-       prepare_custom_config_dict = {
-           # user will manually define the corresponding observed
-           # module class which has a from_float class method that converts
-           # float custom module to observed custom module
-           "float_to_observed_custom_module_class": {
-               CustomModule: ObservedCustomModule
-           }
-        }
-
-    """
-    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
-    if prepare_custom_config_dict is None:
-        prepare_custom_config_dict = {}
-    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
-
-    if not inplace:
-        model = copy.deepcopy(model)
-
-    # TODO: remove allow_list
-    qconfig_propagation_list = allow_list
-    if qconfig_propagation_list is None:
-        qconfig_propagation_list = get_default_qconfig_propagation_list()
-    propagate_qconfig_(model, qconfig_dict=None)
-
-    # sanity check common API misusage
-    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
-        warnings.warn("None of the submodule got qconfig applied. Make sure you "
-                      "passed correct configuration through `qconfig_dict` or "
-                      "by assigning the `.qconfig` attribute directly on submodules")
-
-    add_observer_(
-        model, qconfig_propagation_list, observer_non_leaf_module_list,
-        custom_module_class_mapping=custom_module_class_mapping)
-    return model
-
-def _remove_activation_post_process(module):
-    # TODO: maybe we should change activation_post_process to _activation_post_process
-    # to prevent it from being used by user
-    if hasattr(module, 'activation_post_process') and \
-       is_activation_post_process(module.activation_post_process):
-        delattr(module, 'activation_post_process')
-
-    # remove activation_post_proceess hook
-    handle_ids_to_remove = set()
-    for handle_id, hook_fn in module._forward_hooks.items():
-        if hook_fn is _observer_forward_hook:
-            handle_ids_to_remove.add(handle_id)
-    for handle_id in handle_ids_to_remove:
-        module._forward_hooks.pop(handle_id)
-
-# TODO: rename to something more general
-def _remove_qconfig(module):
-    r"""Clean up the qconfig left in the module so that new qconfig can be
-    propagated.
-
-    Args:
-        module: module to be cleaned up
-    """
-    for child in module.children():
-        _remove_qconfig(child)
-
-    if hasattr(module, "qconfig"):
-        del module.qconfig
-
-    _remove_activation_post_process(module)
-
-def quantize(model, run_fn, run_args, mapping=None, inplace=False):
-    r"""Quantize the input float model with post training static quantization.
-
-    First it will prepare the model for calibration, then it calls
-    `run_fn` which will run the calibration step, after that we will
-    convert the model to a quantized model.
-
-    Args:
-        model: input float model
-        run_fn: a calibration function for calibrating the prepared model
-        run_args: positional arguments for `run_fn`
-        inplace: carry out model transformations in-place, the original module is mutated
-        mapping: correspondence between original module types and quantized counterparts
-
-    Return:
-        Quantized model.
-    """
-    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
-    if mapping is None:
-        mapping = get_default_static_quant_module_mappings()
-    if not inplace:
-        model = copy.deepcopy(model)
-    model.eval()
-    prepare(model, inplace=True)
-    run_fn(model, *run_args)
-    convert(model, mapping, inplace=True)
-    return model
-
-def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
-                     mapping=None, inplace=False):
-    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
-
-    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
-
-    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
-    by default is performed for layers with large weights size - i.e. Linear and RNN variants.
-
-    Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
-    If `qconfig` is provided, the `dtype` argument is ignored.
-
-    Args:
-        model: input model
-        qconfig_spec: Either:
-
-            - A dictionary that maps from name or type of submodule to quantization
-              configuration, qconfig applies to all submodules of a given
-              module unless qconfig for the submodules are specified (when the
-              submodule already has qconfig attribute). Entries in the dictionary
-              need to be QConfigDynamic instances.
-
-            - A set of types and/or submodule names to apply dynamic quantization to,
-              in which case the `dtype` argument is used to specify the bit-width
-
-        inplace: carry out model transformations in-place, the original module is mutated
-        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
-            with which the submodule needs to be replaced
-
-    """
-    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
-    if qconfig_spec is None:
-        if dtype == torch.qint8:
-            qconfig_spec = {
-                nn.Linear : default_dynamic_qconfig,
-                nn.LSTM : default_dynamic_qconfig,
-                nn.GRU : default_dynamic_qconfig,
-                nn.LSTMCell : default_dynamic_qconfig,
-                nn.RNNCell : default_dynamic_qconfig,
-                nn.GRUCell : default_dynamic_qconfig,
-            }
-        elif dtype == torch.float16:
-            qconfig_spec = {
-                nn.Linear : float16_dynamic_qconfig,
-                nn.LSTM : float16_dynamic_qconfig,
-                nn.GRU : float16_dynamic_qconfig,
-                nn.LSTMCell : float16_dynamic_qconfig,
-                nn.RNNCell : float16_dynamic_qconfig,
-                nn.GRUCell : float16_dynamic_qconfig,
-            }
-        elif dtype == torch.quint8:
-            qconfig_spec = {
-                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
-            }
-        else:
-            raise ValueError(
-                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
-    elif isinstance(qconfig_spec, set):
-        if dtype is torch.qint8:
-            default_qconfig = default_dynamic_qconfig
-        elif dtype is torch.float16:
-            default_qconfig = float16_dynamic_qconfig
-        elif dtype is torch.quint8:
-            default_qconfig = float_qparams_weight_only_qconfig
-        else:
-            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
-        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
-
-    if mapping is None:
-        mapping = get_default_dynamic_quant_module_mappings()
-
-    if not inplace:
-        model = copy.deepcopy(model)
-    model.eval()
-    propagate_qconfig_(model, qconfig_spec)
-    convert(model, mapping, inplace=True)
-    return model
-
-def prepare_qat(model, mapping=None, inplace=False):
-    r"""
-    Prepares a copy of the model for quantization calibration or
-    quantization-aware training and converts it to quantized version.
-
-    Quantization configuration should be assigned preemptively
-    to individual submodules in `.qconfig` attribute.
-
-    Args:
-        model: input model to be modified in-place
-        mapping: dictionary that maps float modules to quantized modules to be
-                 replaced.
-        inplace: carry out model transformations in-place, the original module
-                 is mutated
-    """
-    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
-    if mapping is None:
-        mapping = get_default_qat_module_mappings()
-
-    if not inplace:
-        model = copy.deepcopy(model)
-
-    propagate_qconfig_(model, qconfig_dict=None)
-    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
-    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
-    return model
-
-def quantize_qat(model, run_fn, run_args, inplace=False):
-    r"""Do quantization aware training and output a quantized model
-
-    Args:
-        model: input model
-        run_fn: a function for evaluating the prepared model, can be a
-                function that simply runs the prepared model or a training
-                loop
-        run_args: positional arguments for `run_fn`
-
-    Return:
-        Quantized model.
-    """
-    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
-    if not inplace:
-        model = copy.deepcopy(model)
-    model.train()
-    prepare_qat(model, inplace=True)
-    run_fn(model, *run_args)
-    convert(model, inplace=True)
-    return model
-
-def convert(
-        module, mapping=None, inplace=False, remove_qconfig=True,
-        convert_custom_config_dict=None):
-    r"""Converts submodules in input module to a different module according to `mapping`
-    by calling `from_float` method on the target module class. And remove qconfig at the
-    end if remove_qconfig is set to True.
-
-    Args:
-        `module`: prepared and calibrated module
-        `mapping`: a dictionary that maps from source module type to target
-                   module type, can be overwritten to allow swapping user defined
-                   Modules
-        `inplace`: carry out model transformations in-place, the original module
-                   is mutated
-        `convert_custom_config_dict`: custom configuration dictionary for convert function
-
-    .. code-block:: python
-
-       # Example of convert_custom_config_dict:
-       convert_custom_config_dict = {
-           # user will manually define the corresponding quantized
-           # module class which has a from_observed class method that converts
-           # observed custom module to quantized custom module
-           "observed_to_quantized_custom_module_class": {
-               ObservedCustomModule: QuantizedCustomModule
-           }
-       }
-
-    """
-    torch._C._log_api_usage_once("quantization_api.quantize.convert")
-    if not inplace:
-        module = copy.deepcopy(module)
-    _convert(
-        module, mapping, inplace=True,
-        convert_custom_config_dict=convert_custom_config_dict)
-    if remove_qconfig:
-        _remove_qconfig(module)
-    return module
-
-def _convert(
-        module, mapping=None, inplace=False,
-        convert_custom_config_dict=None):
-    r"""Converts submodules in input module to a different module according to `mapping`
-    by calling `from_float` method on the target module class
-
-    Args:
-        module: input module
-        mapping: a dictionary that maps from source module type to target
-                 module type, can be overwritten to allow swapping user defined
-                 Modules
-        inplace: carry out model transformations in-place, the original module
-                 is mutated
-
-    """
-    if mapping is None:
-        mapping = get_default_static_quant_module_mappings()
-    if convert_custom_config_dict is None:
-        convert_custom_config_dict = {}
-    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
-
-    if not inplace:
-        module = copy.deepcopy(module)
-    reassign = {}
-    for name, mod in module.named_children():
-        # both fused modules and observed custom modules are
-        # swapped as one unit
-        if not isinstance(mod, _FusedModule) and \
-           type(mod) not in custom_module_class_mapping:
-            _convert(mod, mapping, True,  # inplace
-                     convert_custom_config_dict)
-        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
-
-    for key, value in reassign.items():
-        module._modules[key] = value
-
-    return module
-
-def swap_module(mod, mapping, custom_module_class_mapping):
-    r"""Swaps the module if it has a quantized counterpart and it has an
-    `observer` attached.
-
-    Args:
-        mod: input module
-        mapping: a dictionary that maps from nn module to nnq module
-
-    Return:
-        The corresponding quantized module of `mod`
-    """
-    new_mod = mod
-    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
-        swapped = False
-        if type(mod) in custom_module_class_mapping:
-            new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
-            swapped = True
-        elif type(mod) in mapping:
-            new_mod = mapping[type(mod)].from_float(mod)
-            swapped = True
-
-        if swapped:
-            # Preserve module's pre forward hooks. They'll be called on quantized input
-            for pre_hook_fn in mod._forward_pre_hooks.values():
-                new_mod.register_forward_pre_hook(pre_hook_fn)
-            # Preserve module's post forward hooks except _observer_forward_hook
-            # After convert they'll work with quantized output
-            for hook_fn in mod._forward_hooks.values():
-                if hook_fn is not _observer_forward_hook:
-                    new_mod.register_forward_hook(hook_fn)
-
-            # respect device affinity when swapping modules
-            devices = get_unique_devices_(mod)
-            assert len(devices) <= 1, (
-                "swap_module only works with cpu or single-device CUDA modules, "
-                "but got devices {}".format(devices)
-            )
-            device = next(iter(devices)) if len(devices) > 0 else None
-            if device:
-                new_mod.to(device)
-    return new_mod
-
-def get_observer_dict(mod, target_dict, prefix=""):
-    r"""Traverse the modules and save all observers into dict.
-    This is mainly used for quantization accuracy debug
-    Args:
-        mod: the top module we want to save all observers
-        prefix: the prefix for the current module
-        target_dict: the dictionary used to save all the observers
-    """
-    def get_prefix(prefix):
-        return prefix if prefix == "" else prefix + '.'
-
-    if hasattr(mod, 'activation_post_process'):
-        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
-    for name, child in mod.named_children():
-        module_prefix = get_prefix(prefix) + name if prefix else name
-        get_observer_dict(child, target_dict, module_prefix)
+from torch.ao.quantization.quantize import _convert
+from torch.ao.quantization.quantize import _observer_forward_hook
+from torch.ao.quantization.quantize import _propagate_qconfig_helper
+from torch.ao.quantization.quantize import _remove_activation_post_process
+from torch.ao.quantization.quantize import _remove_qconfig
+from torch.ao.quantization.quantize import add_observer_
+from torch.ao.quantization.quantize import add_quant_dequant
+from torch.ao.quantization.quantize import convert
+from torch.ao.quantization.quantize import get_observer_dict
+from torch.ao.quantization.quantize import get_unique_devices_
+from torch.ao.quantization.quantize import is_activation_post_process
+from torch.ao.quantization.quantize import prepare
+from torch.ao.quantization.quantize import prepare_qat
+from torch.ao.quantization.quantize import propagate_qconfig_
+from torch.ao.quantization.quantize import quantize
+from torch.ao.quantization.quantize import quantize_dynamic
+from torch.ao.quantization.quantize import quantize_qat
+from torch.ao.quantization.quantize import register_activation_post_process_hook
+from torch.ao.quantization.quantize import swap_module