From a42996f16e165de2c7f301da74a4de7eb3c0f452 Mon Sep 17 00:00:00 2001
From: Zafar Takhirov
Date: Wed, 15 Sep 2021 17:24:09 -0700
Subject: [PATCH] [quant] AO migration of the `fuse_modules.py` (phase 1) (#64913)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64913

The AO team is migrating the existing `torch.quantization` namespace to `torch.ao.quantization`. We are migrating one file at a time to make sure that the internal call sites are updated properly.

This migrates `fuse_modules.py` from `torch.quantization` to `torch.ao.quantization`. At this point both locations are supported; eventually `torch.quantization` will be deprecated.

Test Plan: `buck test mode/dev //caffe2/test:quantization`

Reviewed By: vkuzo

Differential Revision: D30882819

fbshipit-source-id: 1926ad6aa49136aceb5b625dcef4bfde3a2860d4
---
 .../quantization/ao_migration/test_quantization.py |  16 ++
 test/quantization/core/test_workflow_module.py     |   4 +-
 test/quantization/eager/test_equalize_eager.py     |   2 +-
 test/quantization/eager/test_model_numerics.py     |   8 +-
 test/quantization/eager/test_quantize_eager_ptq.py |   2 +-
 torch/ao/quantization/__init__.py                  |  12 +-
 torch/ao/quantization/fuse_modules.py              | 147 ++++++++++++++++++
 torch/quantization/fuse_modules.py                 | 167 +++------------------
 8 files changed, 195 insertions(+), 163 deletions(-)
 create mode 100644 torch/ao/quantization/fuse_modules.py

diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py
index 72893d1..daf62f2 100644
--- a/test/quantization/ao_migration/test_quantization.py
+++ b/test/quantization/ao_migration/test_quantization.py
@@ -93,3 +93,19 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             'enable_observer',
         ]
         self._test_function_import('fake_quantize', function_list)
+
+    def test_package_import_fuse_modules(self):
+        self._test_package_import('fuse_modules')
+
+    def test_function_import_fuse_modules(self):
+        function_list = [
+            '_fuse_modules',
+            '_get_module',
+            '_set_module',
+            'fuse_conv_bn',
+            'fuse_conv_bn_relu',
+            'fuse_known_modules',
+            'fuse_modules',
+            'get_fuser_method',
+        ]
+        self._test_function_import('fuse_modules', function_list)
diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py
index b7782ec..7aaefdd 100644
--- a/test/quantization/core/test_workflow_module.py
+++ b/test/quantization/core/test_workflow_module.py
@@ -794,7 +794,7 @@ class TestDistributed(QuantizationTestCase):
             torch.quantization.DeQuantStub(),
         )
 
-        torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
+        torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
 
         model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
         torch.quantization.prepare_qat(model, inplace=True)
@@ -835,7 +835,7 @@ class TestDistributed(QuantizationTestCase):
         model = Model()
 
         # fuse it
-        fused_model = torch.quantization.fuse_modules(
+        fused_model = torch.ao.quantization.fuse_modules(
             model,
             [['conv', 'bn']],
         )
diff --git a/test/quantization/eager/test_equalize_eager.py b/test/quantization/eager/test_equalize_eager.py
index d2c6710..7e0bfb5 100644
--- a/test/quantization/eager/test_equalize_eager.py
+++ b/test/quantization/eager/test_equalize_eager.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from torch.testing._internal.common_quantization import QuantizationTestCase
 
-from torch.quantization.fuse_modules import fuse_modules
+from torch.ao.quantization.fuse_modules import fuse_modules
 
 import torch.quantization._equalize as _equalize
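The distributed-workflow test above exercises the full eager-mode QAT pipeline. For reference, here is a minimal, self-contained sketch of that flow, not part of the patch itself: the module names '1', '2', '3' index into the nn.Sequential exactly as in the test, and the 'fbgemm' qconfig assumes an x86 build.

import torch
import torch.nn as nn

# Eager QAT flow mirroring the test above: fuse, prepare_qat,
# run data through the fake-quantized model, then convert.
model = nn.Sequential(
    torch.quantization.QuantStub(),
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    torch.quantization.DeQuantStub(),
)
# Fuse conv+bn+relu via the new namespace; both namespaces work after this patch.
torch.ao.quantization.fuse_modules(model, [['1', '2', '3']], inplace=True)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
model(torch.randn(4, 3, 8, 8))  # stand-in for a real training loop
model.eval()
quantized = torch.quantization.convert(model)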
diff --git a/test/quantization/eager/test_model_numerics.py b/test/quantization/eager/test_model_numerics.py
index e384d4c..6d3f28f 100644
--- a/test/quantization/eager/test_model_numerics.py
+++ b/test/quantization/eager/test_model_numerics.py
@@ -23,7 +23,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         qModel = torch.quantization.QuantWrapper(my_model)
         qModel.eval()
         qModel.qconfig = torch.quantization.default_qconfig
-        torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.quantization.prepare(qModel, inplace=True)
         qModel(calib_data)
         torch.quantization.convert(qModel, inplace=True)
@@ -45,7 +45,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         q_model = torch.quantization.QuantWrapper(my_model)
         q_model.eval()
         q_model.qconfig = torch.quantization.default_per_channel_qconfig
-        torch.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.quantization.prepare(q_model)
         q_model(calib_data)
         torch.quantization.convert(q_model)
@@ -67,7 +67,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         fq_model = torch.quantization.QuantWrapper(my_model)
         fq_model.train()
         fq_model.qconfig = torch.quantization.default_qat_qconfig
-        torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.quantization.prepare_qat(fq_model)
         fq_model.eval()
         fq_model.apply(torch.quantization.disable_fake_quant)
@@ -103,7 +103,7 @@ class TestModelNumericsEager(QuantizationTestCase):
         fq_model = torch.quantization.QuantWrapper(my_model)
         fq_model.train()
         fq_model.qconfig = qconfig
-        torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+        torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
         torch.quantization.prepare_qat(fq_model)
         fq_model.eval()
         fq_model.apply(torch.quantization.disable_fake_quant)
diff --git a/test/quantization/eager/test_quantize_eager_ptq.py b/test/quantization/eager/test_quantize_eager_ptq.py
index 10cbd92..144c209 100644
--- a/test/quantization/eager/test_quantize_eager_ptq.py
+++ b/test/quantization/eager/test_quantize_eager_ptq.py
@@ -1001,7 +1001,7 @@ class TestPostTrainingDynamic(QuantizationTestCase):
         model = LinearReluLinearModel().eval()
         qconfig = default_dynamic_qconfig
         qconfig_dict = {'' : qconfig}
-        torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
+        torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
         prepare_dynamic(model, qconfig_dict)
         convert_dynamic(model)
 
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 51029f9..245188e 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -1,11 +1,3 @@
 from .fake_quantize import *  # noqa: F403
-
-# TODO(future PR): fix the typo, should be `__all__`
-_all__ = [
-    # FakeQuantize (for qat)
-    'default_fake_quant', 'default_weight_fake_quant',
-    'default_symmetric_fixed_qparams_fake_quant',
-    'default_affine_fixed_qparams_fake_quant',
-    'default_per_channel_weight_fake_quant',
-    'default_histogram_fake_quant',
-]
+from .fuse_modules import *  # noqa: F403
+from .quantize import *  # noqa: F403
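Since this phase keeps both namespaces alive, a quick sanity check, not part of the patch, is that the legacy import path resolves to the very same function object as the new one:

# Both spellings should refer to one function object while the
# compatibility shim re-exports the AO implementation.
from torch.quantization.fuse_modules import fuse_modules as legacy_fuse_modules
from torch.ao.quantization.fuse_modules import fuse_modules as ao_fuse_modules

assert legacy_fuse_modules is ao_fuse_modules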
diff --git a/torch/ao/quantization/fuse_modules.py b/torch/ao/quantization/fuse_modules.py
new file mode 100644
index 0000000..f702bdf
--- /dev/null
+++ b/torch/ao/quantization/fuse_modules.py
@@ -0,0 +1,147 @@
+
+import copy
+
+import torch.nn as nn
+
+from torch.quantization.fuser_method_mappings import get_fuser_method
+# for backward compatibility
+from torch.quantization.fuser_method_mappings import fuse_conv_bn  # noqa: F401
+from torch.quantization.fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
+
+from typing import List, Optional
+
+# Generalization of getattr: walks a dotted submodule path
+def _get_module(model, submodule_key):
+    tokens = submodule_key.split('.')
+    cur_mod = model
+    for s in tokens:
+        cur_mod = getattr(cur_mod, s)
+    return cur_mod
+
+# Generalization of setattr: replaces the leaf of a dotted submodule path
+def _set_module(model, submodule_key, module):
+    tokens = submodule_key.split('.')
+    sub_tokens = tokens[:-1]
+    cur_mod = model
+    for s in sub_tokens:
+        cur_mod = getattr(cur_mod, s)
+
+    setattr(cur_mod, tokens[-1], module)
+
+def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
+    r"""Returns a list of modules that fuses the operations specified
+    in the input module list.
+
+    Fuses only the following sequences of modules:
+    conv, bn
+    conv, bn, relu
+    conv, relu
+    linear, bn
+    linear, relu
+    For these sequences, the first element in the output module list performs
+    the fused operation. The rest of the elements are set to nn.Identity().
+    """
+    types = tuple(type(m) for m in mod_list)
+    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
+    if fuser_method is None:
+        raise NotImplementedError("Cannot fuse modules: {}".format(types))
+    new_mod: List[Optional[nn.Module]] = [None] * len(mod_list)
+    fused = fuser_method(*mod_list)
+    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
+    # Move pre forward hooks of the base module to resulting fused module
+    # (iterate over a copy, since entries are deleted while iterating)
+    for handle_id, pre_hook_fn in list(mod_list[0]._forward_pre_hooks.items()):
+        fused.register_forward_pre_hook(pre_hook_fn)
+        del mod_list[0]._forward_pre_hooks[handle_id]
+    # Move post forward hooks of the last module to resulting fused module
+    for handle_id, hook_fn in list(mod_list[-1]._forward_hooks.items()):
+        fused.register_forward_hook(hook_fn)
+        del mod_list[-1]._forward_hooks[handle_id]
+    new_mod[0] = fused
+
+    for i in range(1, len(mod_list)):
+        identity = nn.Identity()
+        identity.training = mod_list[0].training
+        new_mod[i] = identity
+
+    return new_mod
+
+def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    if fuse_custom_config_dict is None:
+        fuse_custom_config_dict = {}
+    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
+    mod_list = []
+    for item in modules_to_fuse:
+        mod_list.append(_get_module(model, item))
+
+    # Fuse list of modules
+    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
+
+    # Replace original module list with fused module list
+    for i, item in enumerate(modules_to_fuse):
+        _set_module(model, item, new_mod_list[i])
+
+def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+    r"""Fuses a list of modules into a single module.
+
+    Fuses only the following sequences of modules:
+    conv, bn
+    conv, bn, relu
+    conv, relu
+    linear, relu
+    bn, relu
+    All other sequences are left unchanged.
+    For these sequences, replaces the first item in the list
+    with the fused module, replacing the rest of the modules
+    with identity.
+
+    Args:
+        model: Model containing the modules to be fused
+        modules_to_fuse: list of lists of module names to fuse. Can also be a
+            list of strings if there is only a single list of modules to fuse.
+        inplace: bool specifying if fusion happens in place on the model; by
+            default a new model is returned
+        fuser_func: Function that takes in a list of modules and outputs a
+            list of fused modules of the same length. For example,
+            fuser_func([convModule, BNModule]) returns the list
+            [ConvBNModule, nn.Identity()]. Defaults to
+            torch.ao.quantization.fuse_known_modules
+        fuse_custom_config_dict: custom configuration for fusion
+
+    .. code-block:: python
+
+        # Example of fuse_custom_config_dict
+        fuse_custom_config_dict = {
+            # Additional fuser_method mapping
+            "additional_fuser_method_mapping": {
+                (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
+            },
+        }
+
+    Returns:
+        model with fused modules. A new copy is created if inplace=False.
+
+    Examples::
+
+        >>> m = myModel()
+        >>> # m is a module containing the sub-modules below
+        >>> modules_to_fuse = [['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
+        >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+        >>> output = fused_m(input)
+
+        >>> m = myModel()
+        >>> # Alternately provide a single list of modules to fuse
+        >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
+        >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+        >>> output = fused_m(input)
+
+    """
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+        # Handle case of modules_to_fuse being a list
+        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
+    else:
+        # Handle case of modules_to_fuse being a list of lists
+        for module_list in modules_to_fuse:
+            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
+    return model
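As an illustration of the API defined in the new file (the class name M below is hypothetical), fusing a toy eval-mode model leaves the fused op in the first listed slot and nn.Identity placeholders elsewhere, as described in fuse_known_modules:

import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.bn1(self.conv1(x)))

m = M().eval()
# inplace defaults to False, so `fused` is a deep copy and `m` is untouched.
fused = fuse_modules(m, [['conv1', 'bn1', 'relu1']])
assert isinstance(fused.bn1, nn.Identity) and isinstance(fused.relu1, nn.Identity)
out = fused(torch.randn(1, 3, 8, 8))  # conv1 now runs the fused conv+bn+relu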
+""" + +from torch.ao.quantization.fuse_modules import fuse_modules +from torch.ao.quantization.fuse_modules import fuse_known_modules +from torch.ao.quantization.fuse_modules import get_fuser_method -import copy - -import torch.nn as nn - -from .fuser_method_mappings import get_fuser_method # for backward compatiblity -from .fuser_method_mappings import fuse_conv_bn # noqa: F401 -from .fuser_method_mappings import fuse_conv_bn_relu # noqa: F401 - -from typing import List, Optional - -# Generalization of getattr -def _get_module(model, submodule_key): - tokens = submodule_key.split('.') - cur_mod = model - for s in tokens: - cur_mod = getattr(cur_mod, s) - return cur_mod - -# Generalization of setattr -def _set_module(model, submodule_key, module): - tokens = submodule_key.split('.') - sub_tokens = tokens[:-1] - cur_mod = model - for s in sub_tokens: - cur_mod = getattr(cur_mod, s) - - setattr(cur_mod, tokens[-1], module) - -def fuse_known_modules(mod_list, additional_fuser_method_mapping=None): - r"""Returns a list of modules that fuses the operations specified - in the input module list. - - Fuses only the following sequence of modules: - conv, bn - conv, bn, relu - conv, relu - linear, bn - linear, relu - For these sequences, the first element in the output module list performs - the fused operation. The rest of the elements are set to nn.Identity() - """ - types = tuple(type(m) for m in mod_list) - fuser_method = get_fuser_method(types, additional_fuser_method_mapping) - if fuser_method is None: - raise NotImplementedError("Cannot fuse modules: {}".format(types)) - new_mod : List[Optional[nn.Module]] = [None] * len(mod_list) - fused = fuser_method(*mod_list) - # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion - # Move pre forward hooks of the base module to resulting fused module - for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items(): - fused.register_forward_pre_hook(pre_hook_fn) - del mod_list[0]._forward_pre_hooks[handle_id] - # Move post forward hooks of the last module to resulting fused module - for handle_id, hook_fn in mod_list[-1]._forward_hooks.items(): - fused.register_forward_hook(hook_fn) - del mod_list[-1]._forward_hooks[handle_id] - new_mod[0] = fused - - for i in range(1, len(mod_list)): - identity = nn.Identity() - identity.training = mod_list[0].training - new_mod[i] = identity - - return new_mod - -def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): - if fuse_custom_config_dict is None: - fuse_custom_config_dict = {} - additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) - mod_list = [] - for item in modules_to_fuse: - mod_list.append(_get_module(model, item)) - - # Fuse list of modules - new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping) - - # Replace original module list with fused module list - for i, item in enumerate(modules_to_fuse): - _set_module(model, item, new_mod_list[i]) - -def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None): - r"""Fuses a list of modules into a single module - - Fuses only the following sequence of modules: - conv, bn - conv, bn, relu - conv, relu - linear, relu - bn, relu - All other sequences are left unchanged. - For these sequences, replaces the first item in the list - with the fused module, replacing the rest of the modules - with identity. 
-
-    Args:
-        model: Model containing the modules to be fused
-        modules_to_fuse: list of list of module names to fuse. Can also be a list
-                         of strings if there is only a single list of modules to fuse.
-        inplace: bool specifying if fusion happens in place on the model, by default
-                 a new model is returned
-        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
-                    of the same length. For example,
-                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
-                    Defaults to torch.quantization.fuse_known_modules
-        `fuse_custom_config_dict`: custom configuration for fusion
-
-    .. code-block:: python
-
-       # Example of fuse_custom_config_dict
-       fuse_custom_config_dict = {
-           # Additional fuser_method mapping
-           "additional_fuser_method_mapping": {
-               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
-           },
-       }
-
-    Returns:
-        model with fused modules. A new copy is created if inplace=True.
-
-    Examples::
-
-            >>> m = myModel()
-            >>> # m is a module containing the sub-modules below
-            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
-            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
-            >>> output = fused_m(input)
-
-            >>> m = myModel()
-            >>> # Alternately provide a single list of modules to fuse
-            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
-            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
-            >>> output = fused_m(input)
-
-    """
-    if not inplace:
-        model = copy.deepcopy(model)
-
-    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
-        # Handle case of modules_to_fuse being a list
-        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
-    else:
-        # Handle case of modules_to_fuse being a list of lists
-        for module_list in modules_to_fuse:
-            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
-    return model
+from torch.quantization.fuser_method_mappings import fuse_conv_bn
+from torch.quantization.fuser_method_mappings import fuse_conv_bn_relu
+
+# TODO: These functions are not used outside `fuse_modules.py`.
+# Keeping them here for now; need to remove them later.
+from torch.ao.quantization.fuse_modules import (
+    _fuse_modules,
+    _get_module,
+    _set_module,
+)
-- 
2.7.4
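A note on the private helpers re-exported above: _get_module and _set_module are dotted-path generalizations of getattr/setattr, which is what lets modules_to_fuse name nested submodules such as 'submodule.conv'. A small illustrative sketch, relying on these private names exactly as the patch defines them:

import torch.nn as nn
from torch.ao.quantization.fuse_modules import _get_module, _set_module

# Nested container to exercise the dotted-path helpers.
model = nn.Sequential()
model.add_module('block', nn.Sequential())
model.block.add_module('conv', nn.Conv2d(3, 3, 1))

# _get_module walks the dotted path one getattr at a time.
assert _get_module(model, 'block.conv') is model.block.conv

# _set_module replaces the leaf module on its immediate parent.
_set_module(model, 'block.conv', nn.Identity())
assert isinstance(model.block.conv, nn.Identity)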