From: Zafar Takhirov Date: Thu, 16 Sep 2021 00:24:09 +0000 (-0700) Subject: [quant] AO migration of the `quant_types.py` (phase 1) (#64916) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~166 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c151d62f45d769ee43a611d0776c7325225f7a2a;p=platform%2Fupstream%2Fpytorch.git [quant] AO migration of the `quant_types.py` (phase 1) (#64916) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64916 AO Team is migrating the existing torch.quantization into torch.ao.quantization. We are doing it one file at a time to make sure that the internal callsites are updated properly. This migrates the quant_type.py from torch.quantization to torch.ao.quantization. At this point both locations will be supported. Eventually the torch.quantization will be deprecated. Test Plan: `buck test mode/dev //caffe2/test:quantization -- TestAOMigrationQuantization` Reviewed By: vkuzo Differential Revision: D30898422 fbshipit-source-id: 3e6126b49f0565a4136d6928cea9eb25368927ff --- diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index daf62f2..68ba478 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -109,3 +109,13 @@ class TestAOMigrationQuantization(AOMigrationTestCase): 'get_fuser_method', ] self._test_function_import('fuse_modules', function_list) + + def test_package_import_quant_type(self): + self._test_package_import('quant_type') + + def test_function_import_quant_type(self): + function_list = [ + 'QuantType', + 'quant_type_to_str', + ] + self._test_function_import('quant_type', function_list) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index f2f665d..182f5ad 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -24,12 +24,15 @@ from torch.quantization.fx.match_utils import ( MatchAllNode, ) -from torch.quantization import ( +from torch.ao.quantization import ( QuantType, + quant_type_to_str, +) + +from torch.quantization import ( QuantStub, DeQuantStub, QuantWrapper, - quant_type_to_str, default_qconfig, default_dynamic_qconfig, default_qat_qconfig, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 245188e..0950b84 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,3 +1,4 @@ from .fake_quantize import * # noqa: F403 from .fuse_modules import * # noqa: F403 +from .quant_type import * # noqa: F403 from .quantize import * # noqa: F403 diff --git a/torch/ao/quantization/quant_type.py b/torch/ao/quantization/quant_type.py new file mode 100644 index 0000000..463d086 --- /dev/null +++ b/torch/ao/quantization/quant_type.py @@ -0,0 +1,19 @@ +import enum + +# Quantization type (dynamic quantization, static quantization). +# Should match the c++ enum in quantization_type.h +class QuantType(enum.IntEnum): + DYNAMIC = 0 + STATIC = 1 + QAT = 2 + WEIGHT_ONLY = 3 + + +def quant_type_to_str(quant_type): + m = { + QuantType.STATIC: "static", + QuantType.DYNAMIC: "dynamic", + QuantType.QAT: "qat", + QuantType.WEIGHT_ONLY: "weight_only", + } + return m[quant_type] diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py index 463d086..cd2e5e0 100644 --- a/torch/quantization/quant_type.py +++ b/torch/quantization/quant_type.py @@ -1,19 +1,11 @@ -import enum +# flake8: noqa: F401 +r""" +This file is in the process of migration to `torch/ao/quantization`, and +is kept here for compatibility while the migration process is ongoing. +If you are adding a new entry/functionality, please, add it to the +`torch/ao/quantization/quant_type.py`, while adding an import statement +here. +""" -# Quantization type (dynamic quantization, static quantization). -# Should match the c++ enum in quantization_type.h -class QuantType(enum.IntEnum): - DYNAMIC = 0 - STATIC = 1 - QAT = 2 - WEIGHT_ONLY = 3 - - -def quant_type_to_str(quant_type): - m = { - QuantType.STATIC: "static", - QuantType.DYNAMIC: "dynamic", - QuantType.QAT: "qat", - QuantType.WEIGHT_ONLY: "weight_only", - } - return m[quant_type] +from torch.ao.quantization.quant_type import QuantType +from torch.ao.quantization.quant_type import quant_type_to_str diff --git a/torch/quantization/utils.py b/torch/quantization/utils.py index 9c5198b..12ca5d9 100644 --- a/torch/quantization/utils.py +++ b/torch/quantization/utils.py @@ -4,7 +4,7 @@ Utils shared by different modes of quantization (eager/graph) import warnings import functools import torch -from .quant_type import QuantType, quant_type_to_str +from torch.ao.quantization.quant_type import QuantType, quant_type_to_str from typing import Tuple, Any def get_combined_dict(default_dict, additional_dict): diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 33e758c..315dab4 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -12,10 +12,11 @@ from torch.nn.intrinsic import _FusedModule import torch.distributed as dist from torch.testing._internal.common_utils import TestCase +from torch.ao.quantization import QuantType from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \ - get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType, quantize + get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, quantize from torch.quantization.quantization_mappings import ( get_default_dynamic_quant_module_mappings, get_default_qconfig_propagation_list,