[quant] AO migration of the `fuse_modules.py` (phase 1) (#64913)
authorZafar Takhirov <zaf@fb.com>
Thu, 16 Sep 2021 00:24:09 +0000 (17:24 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 00:28:47 +0000 (17:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64913

AO Team is migrating the existing torch.quantization into torch.ao.quantization. We are doing it one file at a time to make sure that the internal callsites are updated properly.
This migrates the fuse_module.py from torch.quantization to torch.ao.quantization.
At this point both locations will be supported. Eventually the torch.quantization will be deprecated.

Test Plan: `buck test mode/dev //caffe2/test:quantization`

Reviewed By: vkuzo

Differential Revision: D30882819

fbshipit-source-id: 1926ad6aa49136aceb5b625dcef4bfde3a2860d4

test/quantization/ao_migration/test_quantization.py
test/quantization/core/test_workflow_module.py
test/quantization/eager/test_equalize_eager.py
test/quantization/eager/test_model_numerics.py
test/quantization/eager/test_quantize_eager_ptq.py
torch/ao/quantization/__init__.py
torch/ao/quantization/fuse_modules.py [new file with mode: 0644]
torch/quantization/fuse_modules.py

index 72893d1..daf62f2 100644 (file)
@@ -93,3 +93,19 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             'enable_observer',
         ]
         self._test_function_import('fake_quantize', function_list)
+
+    def test_package_import_fuse_modules(self):
+        self._test_package_import('fuse_modules')
+
+    def test_function_import_fuse_modules(self):
+        function_list = [
+            '_fuse_modules',
+            '_get_module',
+            '_set_module',
+            'fuse_conv_bn',
+            'fuse_conv_bn_relu',
+            'fuse_known_modules',
+            'fuse_modules',
+            'get_fuser_method',
+        ]
+        self._test_function_import('fuse_modules', function_list)
index b7782ec..7aaefdd 100644 (file)
@@ -794,7 +794,7 @@ class TestDistributed(QuantizationTestCase):
                 torch.quantization.DeQuantStub(),
             )
 
-            torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
+            torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
 
             model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
             torch.quantization.prepare_qat(model, inplace=True)
@@ -835,7 +835,7 @@ class TestDistributed(QuantizationTestCase):
 
             model = Model()
             # fuse it
-            fused_model = torch.quantization.fuse_modules(
+            fused_model = torch.ao.quantization.fuse_modules(
                 model,
                 [['conv', 'bn']],
             )
index d2c6710..7e0bfb5 100644 (file)
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.quantization.fuse_modules import fuse_modules
+from torch.ao.quantization.fuse_modules import fuse_modules
 
 import torch.quantization._equalize as _equalize
 
index e384d4c..6d3f28f 100644 (file)
@@ -23,7 +23,7 @@ class TestModelNumericsEager(QuantizationTestCase):
                 qModel = torch.quantization.QuantWrapper(my_model)
                 qModel.eval()
                 qModel.qconfig = torch.quantization.default_qconfig
-                torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+                torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
                 torch.quantization.prepare(qModel, inplace=True)
                 qModel(calib_data)
                 torch.quantization.convert(qModel, inplace=True)
@@ -45,7 +45,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         q_model = torch.quantization.QuantWrapper(my_model)
         q_model.eval()
         q_model.qconfig = torch.quantization.default_per_channel_qconfig
-        torch.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.quantization.prepare(q_model)
         q_model(calib_data)
         torch.quantization.convert(q_model)
@@ -67,7 +67,7 @@ class TestModelNumericsEager(QuantizationTestCase):
                 fq_model = torch.quantization.QuantWrapper(my_model)
                 fq_model.train()
                 fq_model.qconfig = torch.quantization.default_qat_qconfig
-                torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+                torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
                 torch.quantization.prepare_qat(fq_model)
                 fq_model.eval()
                 fq_model.apply(torch.quantization.disable_fake_quant)
@@ -103,7 +103,7 @@ class TestModelNumericsEager(QuantizationTestCase):
                     fq_model = torch.quantization.QuantWrapper(my_model)
                     fq_model.train()
                     fq_model.qconfig = qconfig
-                    torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+                    torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
                     torch.quantization.prepare_qat(fq_model)
                     fq_model.eval()
                     fq_model.apply(torch.quantization.disable_fake_quant)
index 10cbd92..144c209 100644 (file)
@@ -1001,7 +1001,7 @@ class TestPostTrainingDynamic(QuantizationTestCase):
         model = LinearReluLinearModel().eval()
         qconfig = default_dynamic_qconfig
         qconfig_dict = {'' : qconfig}
-        torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
+        torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
         prepare_dynamic(model, qconfig_dict)
         convert_dynamic(model)
 
index 51029f9..245188e 100644 (file)
@@ -1,11 +1,3 @@
 from .fake_quantize import *  # noqa: F403
-
-# TODO(future PR): fix the typo, should be `__all__`
-_all__ = [
-    # FakeQuantize (for qat)
-    'default_fake_quant', 'default_weight_fake_quant',
-    'default_symmetric_fixed_qparams_fake_quant',
-    'default_affine_fixed_qparams_fake_quant',
-    'default_per_channel_weight_fake_quant',
-    'default_histogram_fake_quant',
-]
+from .fuse_modules import *  # noqa: F403
+from .quantize import *  # noqa: F403
diff --git a/torch/ao/quantization/fuse_modules.py b/torch/ao/quantization/fuse_modules.py
new file mode 100644 (file)
index 0000000..f702bdf
--- /dev/null
@@ -0,0 +1,147 @@
+
+import copy
+
+import torch.nn as nn
+
+from torch.quantization.fuser_method_mappings import get_fuser_method
+# for backward compatiblity
+from torch.quantization.fuser_method_mappings import fuse_conv_bn  # noqa: F401
+from torch.quantization.fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
+
+from typing import List, Optional
+
+# Generalization of getattr
+def _get_module(model, submodule_key):
+    tokens = submodule_key.split('.')
+    cur_mod = model
+    for s in tokens:
+        cur_mod = getattr(cur_mod, s)
+    return cur_mod
+
+# Generalization of setattr
+def _set_module(model, submodule_key, module):
+    tokens = submodule_key.split('.')
+    sub_tokens = tokens[:-1]
+    cur_mod = model
+    for s in sub_tokens:
+        cur_mod = getattr(cur_mod, s)
+
+    setattr(cur_mod, tokens[-1], module)
+
+def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
+    r"""Returns a list of modules that fuses the operations specified
+     in the input module list.
+
+    Fuses only the following sequence of modules:
+    conv, bn
+    conv, bn, relu
+    conv, relu
+    linear, bn
+    linear, relu
+    For these sequences, the first element in the output module list performs
+    the fused operation. The rest of the elements are set to nn.Identity()
+    """
+    types = tuple(type(m) for m in mod_list)
+    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
+    if fuser_method is None:
+        raise NotImplementedError("Cannot fuse modules: {}".format(types))
+    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
+    fused = fuser_method(*mod_list)
+    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
+    # Move pre forward hooks of the base module to resulting fused module
+    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
+        fused.register_forward_pre_hook(pre_hook_fn)
+        del mod_list[0]._forward_pre_hooks[handle_id]
+    # Move post forward hooks of the last module to resulting fused module
+    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
+        fused.register_forward_hook(hook_fn)
+        del mod_list[-1]._forward_hooks[handle_id]
+    new_mod[0] = fused
+
+    for i in range(1, len(mod_list)):
+        identity = nn.Identity()
+        identity.training = mod_list[0].training
+        new_mod[i] = identity
+
+    return new_mod
+
+def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    if fuse_custom_config_dict is None:
+        fuse_custom_config_dict = {}
+    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
+    mod_list = []
+    for item in modules_to_fuse:
+        mod_list.append(_get_module(model, item))
+
+    # Fuse list of modules
+    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
+
+    # Replace original module list with fused module list
+    for i, item in enumerate(modules_to_fuse):
+        _set_module(model, item, new_mod_list[i])
+
+def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    r"""Fuses a list of modules into a single module
+
+    Fuses only the following sequence of modules:
+    conv, bn
+    conv, bn, relu
+    conv, relu
+    linear, relu
+    bn, relu
+    All other sequences are left unchanged.
+    For these sequences, replaces the first item in the list
+    with the fused module, replacing the rest of the modules
+    with identity.
+
+    Args:
+        model: Model containing the modules to be fused
+        modules_to_fuse: list of list of module names to fuse. Can also be a list
+                         of strings if there is only a single list of modules to fuse.
+        inplace: bool specifying if fusion happens in place on the model, by default
+                 a new model is returned
+        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
+                    of the same length. For example,
+                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
+                    Defaults to torch.quantization.fuse_known_modules
+        `fuse_custom_config_dict`: custom configuration for fusion
+
+    .. code-block:: python
+
+       # Example of fuse_custom_config_dict
+       fuse_custom_config_dict = {
+           # Additional fuser_method mapping
+           "additional_fuser_method_mapping": {
+               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
+           },
+       }
+
+    Returns:
+        model with fused modules. A new copy is created if inplace=True.
+
+    Examples::
+
+            >>> m = myModel()
+            >>> # m is a module containing  the sub-modules below
+            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
+            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+            >>> output = fused_m(input)
+
+            >>> m = myModel()
+            >>> # Alternately provide a single list of modules to fuse
+            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
+            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+            >>> output = fused_m(input)
+
+    """
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+        # Handle case of modules_to_fuse being a list
+        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
+    else:
+        # Handle case of modules_to_fuse being a list of lists
+        for module_list in modules_to_fuse:
+            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
+    return model
index d4eec00..896f357 100644 (file)
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/quantization/fuse_modules.py`, while adding an import statement
+here.
+"""
+
+from torch.ao.quantization.fuse_modules import fuse_modules
+from torch.ao.quantization.fuse_modules import fuse_known_modules
+from torch.ao.quantization.fuse_modules import get_fuser_method
 
-import copy
-
-import torch.nn as nn
-
-from .fuser_method_mappings import get_fuser_method
 # for backward compatiblity
-from .fuser_method_mappings import fuse_conv_bn  # noqa: F401
-from .fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
-
-from typing import List, Optional
-
-# Generalization of getattr
-def _get_module(model, submodule_key):
-    tokens = submodule_key.split('.')
-    cur_mod = model
-    for s in tokens:
-        cur_mod = getattr(cur_mod, s)
-    return cur_mod
-
-# Generalization of setattr
-def _set_module(model, submodule_key, module):
-    tokens = submodule_key.split('.')
-    sub_tokens = tokens[:-1]
-    cur_mod = model
-    for s in sub_tokens:
-        cur_mod = getattr(cur_mod, s)
-
-    setattr(cur_mod, tokens[-1], module)
-
-def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
-    r"""Returns a list of modules that fuses the operations specified
-     in the input module list.
-
-    Fuses only the following sequence of modules:
-    conv, bn
-    conv, bn, relu
-    conv, relu
-    linear, bn
-    linear, relu
-    For these sequences, the first element in the output module list performs
-    the fused operation. The rest of the elements are set to nn.Identity()
-    """
-    types = tuple(type(m) for m in mod_list)
-    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
-    if fuser_method is None:
-        raise NotImplementedError("Cannot fuse modules: {}".format(types))
-    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
-    fused = fuser_method(*mod_list)
-    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
-    # Move pre forward hooks of the base module to resulting fused module
-    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
-        fused.register_forward_pre_hook(pre_hook_fn)
-        del mod_list[0]._forward_pre_hooks[handle_id]
-    # Move post forward hooks of the last module to resulting fused module
-    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
-        fused.register_forward_hook(hook_fn)
-        del mod_list[-1]._forward_hooks[handle_id]
-    new_mod[0] = fused
-
-    for i in range(1, len(mod_list)):
-        identity = nn.Identity()
-        identity.training = mod_list[0].training
-        new_mod[i] = identity
-
-    return new_mod
-
-def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
-    if fuse_custom_config_dict is None:
-        fuse_custom_config_dict = {}
-    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
-    mod_list = []
-    for item in modules_to_fuse:
-        mod_list.append(_get_module(model, item))
-
-    # Fuse list of modules
-    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
-
-    # Replace original module list with fused module list
-    for i, item in enumerate(modules_to_fuse):
-        _set_module(model, item, new_mod_list[i])
-
-def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
-    r"""Fuses a list of modules into a single module
-
-    Fuses only the following sequence of modules:
-    conv, bn
-    conv, bn, relu
-    conv, relu
-    linear, relu
-    bn, relu
-    All other sequences are left unchanged.
-    For these sequences, replaces the first item in the list
-    with the fused module, replacing the rest of the modules
-    with identity.
-
-    Args:
-        model: Model containing the modules to be fused
-        modules_to_fuse: list of list of module names to fuse. Can also be a list
-                         of strings if there is only a single list of modules to fuse.
-        inplace: bool specifying if fusion happens in place on the model, by default
-                 a new model is returned
-        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
-                    of the same length. For example,
-                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
-                    Defaults to torch.quantization.fuse_known_modules
-        `fuse_custom_config_dict`: custom configuration for fusion
-
-    .. code-block:: python
-
-       # Example of fuse_custom_config_dict
-       fuse_custom_config_dict = {
-           # Additional fuser_method mapping
-           "additional_fuser_method_mapping": {
-               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
-           },
-       }
-
-    Returns:
-        model with fused modules. A new copy is created if inplace=True.
-
-    Examples::
-
-            >>> m = myModel()
-            >>> # m is a module containing  the sub-modules below
-            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
-            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
-            >>> output = fused_m(input)
-
-            >>> m = myModel()
-            >>> # Alternately provide a single list of modules to fuse
-            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
-            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
-            >>> output = fused_m(input)
-
-    """
-    if not inplace:
-        model = copy.deepcopy(model)
-
-    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
-        # Handle case of modules_to_fuse being a list
-        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
-    else:
-        # Handle case of modules_to_fuse being a list of lists
-        for module_list in modules_to_fuse:
-            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
-    return model
+from torch.quantization.fuser_method_mappings import fuse_conv_bn
+from torch.quantization.fuser_method_mappings import fuse_conv_bn_relu
+
+# TODO: These functions are not used outside the `fuse_modules.py`
+#       Keeping here for now, need to remove them later.
+from torch.ao.quantization.fuse_modules import (
+    _fuse_modules,
+    _get_module,
+    _set_module,
+)