'enable_observer',
]
self._test_function_import('fake_quantize', function_list)
+
+ def test_package_import_fuse_modules(self):
+ self._test_package_import('fuse_modules')
+
+ def test_function_import_fuse_modules(self):
+ function_list = [
+ '_fuse_modules',
+ '_get_module',
+ '_set_module',
+ 'fuse_conv_bn',
+ 'fuse_conv_bn_relu',
+ 'fuse_known_modules',
+ 'fuse_modules',
+ 'get_fuser_method',
+ ]
+ self._test_function_import('fuse_modules', function_list)
torch.quantization.DeQuantStub(),
)
- torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
+ torch.ao.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
model = Model()
# fuse it
- fused_model = torch.quantization.fuse_modules(
+ fused_model = torch.ao.quantization.fuse_modules(
model,
[['conv', 'bn']],
)
import torch.nn as nn
from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.quantization.fuse_modules import fuse_modules
+from torch.ao.quantization.fuse_modules import fuse_modules
import torch.quantization._equalize as _equalize
qModel = torch.quantization.QuantWrapper(my_model)
qModel.eval()
qModel.qconfig = torch.quantization.default_qconfig
- torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+ torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare(qModel, inplace=True)
qModel(calib_data)
torch.quantization.convert(qModel, inplace=True)
q_model = torch.quantization.QuantWrapper(my_model)
q_model.eval()
q_model.qconfig = torch.quantization.default_per_channel_qconfig
- torch.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+ torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare(q_model)
q_model(calib_data)
torch.quantization.convert(q_model)
fq_model = torch.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = torch.quantization.default_qat_qconfig
- torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+ torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.quantization.disable_fake_quant)
fq_model = torch.quantization.QuantWrapper(my_model)
fq_model.train()
fq_model.qconfig = qconfig
- torch.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+ torch.ao.quantization.fuse_modules(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
torch.quantization.prepare_qat(fq_model)
fq_model.eval()
fq_model.apply(torch.quantization.disable_fake_quant)
model = LinearReluLinearModel().eval()
qconfig = default_dynamic_qconfig
qconfig_dict = {'' : qconfig}
- torch.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
+ torch.ao.quantization.fuse_modules(model, [['fc1', 'relu']], inplace=True)
prepare_dynamic(model, qconfig_dict)
convert_dynamic(model)
from .fake_quantize import * # noqa: F403
-
-# TODO(future PR): fix the typo, should be `__all__`
-_all__ = [
- # FakeQuantize (for qat)
- 'default_fake_quant', 'default_weight_fake_quant',
- 'default_symmetric_fixed_qparams_fake_quant',
- 'default_affine_fixed_qparams_fake_quant',
- 'default_per_channel_weight_fake_quant',
- 'default_histogram_fake_quant',
-]
+from .fuse_modules import * # noqa: F403
+from .quantize import * # noqa: F403
--- /dev/null
+
+import copy
+
+import torch.nn as nn
+
+from torch.quantization.fuser_method_mappings import get_fuser_method
+# for backward compatiblity
+from torch.quantization.fuser_method_mappings import fuse_conv_bn # noqa: F401
+from torch.quantization.fuser_method_mappings import fuse_conv_bn_relu # noqa: F401
+
+from typing import List, Optional
+
+# Generalization of getattr
+def _get_module(model, submodule_key):
+ tokens = submodule_key.split('.')
+ cur_mod = model
+ for s in tokens:
+ cur_mod = getattr(cur_mod, s)
+ return cur_mod
+
+# Generalization of setattr
+def _set_module(model, submodule_key, module):
+ tokens = submodule_key.split('.')
+ sub_tokens = tokens[:-1]
+ cur_mod = model
+ for s in sub_tokens:
+ cur_mod = getattr(cur_mod, s)
+
+ setattr(cur_mod, tokens[-1], module)
+
+def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
+ r"""Returns a list of modules that fuses the operations specified
+ in the input module list.
+
+ Fuses only the following sequence of modules:
+ conv, bn
+ conv, bn, relu
+ conv, relu
+ linear, bn
+ linear, relu
+ For these sequences, the first element in the output module list performs
+ the fused operation. The rest of the elements are set to nn.Identity()
+ """
+ types = tuple(type(m) for m in mod_list)
+ fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
+ if fuser_method is None:
+ raise NotImplementedError("Cannot fuse modules: {}".format(types))
+ new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
+ fused = fuser_method(*mod_list)
+ # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
+ # Move pre forward hooks of the base module to resulting fused module
+ for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
+ fused.register_forward_pre_hook(pre_hook_fn)
+ del mod_list[0]._forward_pre_hooks[handle_id]
+ # Move post forward hooks of the last module to resulting fused module
+ for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
+ fused.register_forward_hook(hook_fn)
+ del mod_list[-1]._forward_hooks[handle_id]
+ new_mod[0] = fused
+
+ for i in range(1, len(mod_list)):
+ identity = nn.Identity()
+ identity.training = mod_list[0].training
+ new_mod[i] = identity
+
+ return new_mod
+
+def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+ if fuse_custom_config_dict is None:
+ fuse_custom_config_dict = {}
+ additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
+ mod_list = []
+ for item in modules_to_fuse:
+ mod_list.append(_get_module(model, item))
+
+ # Fuse list of modules
+ new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
+
+ # Replace original module list with fused module list
+ for i, item in enumerate(modules_to_fuse):
+ _set_module(model, item, new_mod_list[i])
+
+def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
+ r"""Fuses a list of modules into a single module
+
+ Fuses only the following sequence of modules:
+ conv, bn
+ conv, bn, relu
+ conv, relu
+ linear, relu
+ bn, relu
+ All other sequences are left unchanged.
+ For these sequences, replaces the first item in the list
+ with the fused module, replacing the rest of the modules
+ with identity.
+
+ Args:
+ model: Model containing the modules to be fused
+ modules_to_fuse: list of list of module names to fuse. Can also be a list
+ of strings if there is only a single list of modules to fuse.
+ inplace: bool specifying if fusion happens in place on the model, by default
+ a new model is returned
+ fuser_func: Function that takes in a list of modules and outputs a list of fused modules
+ of the same length. For example,
+ fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
+ Defaults to torch.quantization.fuse_known_modules
+ `fuse_custom_config_dict`: custom configuration for fusion
+
+ .. code-block:: python
+
+ # Example of fuse_custom_config_dict
+ fuse_custom_config_dict = {
+ # Additional fuser_method mapping
+ "additional_fuser_method_mapping": {
+ (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
+ },
+ }
+
+ Returns:
+ model with fused modules. A new copy is created if inplace=True.
+
+ Examples::
+
+ >>> m = myModel()
+ >>> # m is a module containing the sub-modules below
+ >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
+ >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+ >>> output = fused_m(input)
+
+ >>> m = myModel()
+ >>> # Alternately provide a single list of modules to fuse
+ >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
+ >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
+ >>> output = fused_m(input)
+
+ """
+ if not inplace:
+ model = copy.deepcopy(model)
+
+ if all(isinstance(module_element, str) for module_element in modules_to_fuse):
+ # Handle case of modules_to_fuse being a list
+ _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
+ else:
+ # Handle case of modules_to_fuse being a list of lists
+ for module_list in modules_to_fuse:
+ _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
+ return model
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please, add it to the
+`torch/ao/quantization/fuse_modules.py`, while adding an import statement
+here.
+"""
+
+from torch.ao.quantization.fuse_modules import fuse_modules
+from torch.ao.quantization.fuse_modules import fuse_known_modules
+from torch.ao.quantization.fuse_modules import get_fuser_method
-import copy
-
-import torch.nn as nn
-
-from .fuser_method_mappings import get_fuser_method
# for backward compatiblity
-from .fuser_method_mappings import fuse_conv_bn # noqa: F401
-from .fuser_method_mappings import fuse_conv_bn_relu # noqa: F401
-
-from typing import List, Optional
-
-# Generalization of getattr
-def _get_module(model, submodule_key):
- tokens = submodule_key.split('.')
- cur_mod = model
- for s in tokens:
- cur_mod = getattr(cur_mod, s)
- return cur_mod
-
-# Generalization of setattr
-def _set_module(model, submodule_key, module):
- tokens = submodule_key.split('.')
- sub_tokens = tokens[:-1]
- cur_mod = model
- for s in sub_tokens:
- cur_mod = getattr(cur_mod, s)
-
- setattr(cur_mod, tokens[-1], module)
-
-def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
- r"""Returns a list of modules that fuses the operations specified
- in the input module list.
-
- Fuses only the following sequence of modules:
- conv, bn
- conv, bn, relu
- conv, relu
- linear, bn
- linear, relu
- For these sequences, the first element in the output module list performs
- the fused operation. The rest of the elements are set to nn.Identity()
- """
- types = tuple(type(m) for m in mod_list)
- fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
- if fuser_method is None:
- raise NotImplementedError("Cannot fuse modules: {}".format(types))
- new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
- fused = fuser_method(*mod_list)
- # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
- # Move pre forward hooks of the base module to resulting fused module
- for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
- fused.register_forward_pre_hook(pre_hook_fn)
- del mod_list[0]._forward_pre_hooks[handle_id]
- # Move post forward hooks of the last module to resulting fused module
- for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
- fused.register_forward_hook(hook_fn)
- del mod_list[-1]._forward_hooks[handle_id]
- new_mod[0] = fused
-
- for i in range(1, len(mod_list)):
- identity = nn.Identity()
- identity.training = mod_list[0].training
- new_mod[i] = identity
-
- return new_mod
-
-def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
- if fuse_custom_config_dict is None:
- fuse_custom_config_dict = {}
- additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
- mod_list = []
- for item in modules_to_fuse:
- mod_list.append(_get_module(model, item))
-
- # Fuse list of modules
- new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
-
- # Replace original module list with fused module list
- for i, item in enumerate(modules_to_fuse):
- _set_module(model, item, new_mod_list[i])
-
-def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
- r"""Fuses a list of modules into a single module
-
- Fuses only the following sequence of modules:
- conv, bn
- conv, bn, relu
- conv, relu
- linear, relu
- bn, relu
- All other sequences are left unchanged.
- For these sequences, replaces the first item in the list
- with the fused module, replacing the rest of the modules
- with identity.
-
- Args:
- model: Model containing the modules to be fused
- modules_to_fuse: list of list of module names to fuse. Can also be a list
- of strings if there is only a single list of modules to fuse.
- inplace: bool specifying if fusion happens in place on the model, by default
- a new model is returned
- fuser_func: Function that takes in a list of modules and outputs a list of fused modules
- of the same length. For example,
- fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
- Defaults to torch.quantization.fuse_known_modules
- `fuse_custom_config_dict`: custom configuration for fusion
-
- .. code-block:: python
-
- # Example of fuse_custom_config_dict
- fuse_custom_config_dict = {
- # Additional fuser_method mapping
- "additional_fuser_method_mapping": {
- (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
- },
- }
-
- Returns:
- model with fused modules. A new copy is created if inplace=True.
-
- Examples::
-
- >>> m = myModel()
- >>> # m is a module containing the sub-modules below
- >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
- >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
- >>> output = fused_m(input)
-
- >>> m = myModel()
- >>> # Alternately provide a single list of modules to fuse
- >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
- >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
- >>> output = fused_m(input)
-
- """
- if not inplace:
- model = copy.deepcopy(model)
-
- if all(isinstance(module_element, str) for module_element in modules_to_fuse):
- # Handle case of modules_to_fuse being a list
- _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
- else:
- # Handle case of modules_to_fuse being a list of lists
- for module_list in modules_to_fuse:
- _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
- return model
+from torch.quantization.fuser_method_mappings import fuse_conv_bn
+from torch.quantization.fuser_method_mappings import fuse_conv_bn_relu
+
+# TODO: These functions are not used outside the `fuse_modules.py`
+# Keeping here for now, need to remove them later.
+from torch.ao.quantization.fuse_modules import (
+ _fuse_modules,
+ _get_module,
+ _set_module,
+)