[quant] AO migration of `quant_type.py` (phase 1) (#64916)
author    Zafar Takhirov <zaf@fb.com>
Thu, 16 Sep 2021 00:24:09 +0000 (17:24 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 00:30:00 +0000 (17:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64916

The AO team is migrating the existing torch.quantization namespace into torch.ao.quantization. We are doing this one file at a time to make sure that internal call sites are updated properly.
This change migrates quant_type.py from torch.quantization to torch.ao.quantization.
At this point, both locations are supported. Eventually, torch.quantization will be deprecated.
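During the migration period, torch/quantization/quant_type.py becomes a thin shim that re-exports the implementation from torch.ao.quantization.quant_type, so both import paths resolve to the same objects. A minimal sketch of what that compatibility means for user code (illustrative only, not part of this patch):

    # Both import paths work while the migration is ongoing.
    from torch.ao.quantization import QuantType, quant_type_to_str  # new location
    from torch.quantization import QuantType as LegacyQuantType     # legacy shim

    # The legacy module re-exports the new implementation, so both names
    # refer to the same enum class.
    assert QuantType is LegacyQuantType
    print(quant_type_to_str(QuantType.DYNAMIC))  # prints "dynamic"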

Test Plan: `buck test mode/dev //caffe2/test:quantization -- TestAOMigrationQuantization`

Reviewed By: vkuzo

Differential Revision: D30898422

fbshipit-source-id: 3e6126b49f0565a4136d6928cea9eb25368927ff

test/quantization/ao_migration/test_quantization.py
test/quantization/fx/test_quantize_fx.py
torch/ao/quantization/__init__.py
torch/ao/quantization/quant_type.py [new file with mode: 0644]
torch/quantization/quant_type.py
torch/quantization/utils.py
torch/testing/_internal/common_quantization.py

diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py
index daf62f2..68ba478 100644
@@ -109,3 +109,13 @@ class TestAOMigrationQuantization(AOMigrationTestCase):
             'get_fuser_method',
         ]
         self._test_function_import('fuse_modules', function_list)
+
+    def test_package_import_quant_type(self):
+        self._test_package_import('quant_type')
+
+    def test_function_import_quant_type(self):
+        function_list = [
+            'QuantType',
+            'quant_type_to_str',
+        ]
+        self._test_function_import('quant_type', function_list)
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index f2f665d..182f5ad 100644
@@ -24,12 +24,15 @@ from torch.quantization.fx.match_utils import (
     MatchAllNode,
 )
 
-from torch.quantization import (
+from torch.ao.quantization import (
     QuantType,
+    quant_type_to_str,
+)
+
+from torch.quantization import (
     QuantStub,
     DeQuantStub,
     QuantWrapper,
-    quant_type_to_str,
     default_qconfig,
     default_dynamic_qconfig,
     default_qat_qconfig,
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 245188e..0950b84 100644
@@ -1,3 +1,4 @@
 from .fake_quantize import *  # noqa: F403
 from .fuse_modules import *  # noqa: F403
+from .quant_type import *  # noqa: F403
 from .quantize import *  # noqa: F403
diff --git a/torch/ao/quantization/quant_type.py b/torch/ao/quantization/quant_type.py
new file mode 100644
index 0000000..463d086
--- /dev/null
+++ b/torch/ao/quantization/quant_type.py
@@ -0,0 +1,19 @@
+import enum
+
+# Quantization type (dynamic, static, QAT, or weight-only quantization).
+# Should match the C++ enum in quantization_type.h
+class QuantType(enum.IntEnum):
+    DYNAMIC = 0
+    STATIC = 1
+    QAT = 2
+    WEIGHT_ONLY = 3
+
+
+def quant_type_to_str(quant_type):
+    m = {
+        QuantType.STATIC: "static",
+        QuantType.DYNAMIC: "dynamic",
+        QuantType.QAT: "qat",
+        QuantType.WEIGHT_ONLY: "weight_only",
+    }
+    return m[quant_type]
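For reference, a quick usage sketch of the newly added module (illustrative, not part of the patch):

    from torch.ao.quantization.quant_type import QuantType, quant_type_to_str

    # QuantType is an IntEnum, so members compare equal to their integer
    # values, which mirror the C++ enum.
    assert QuantType.STATIC == 1
    # quant_type_to_str maps each member to its lowercase string name.
    assert quant_type_to_str(QuantType.WEIGHT_ONLY) == "weight_only"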
diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py
index 463d086..cd2e5e0 100644
@@ -1,19 +1,11 @@
-import enum
+# flake8: noqa: F401
+r"""
+This file is in the process of migration to `torch/ao/quantization`, and
+is kept here for compatibility while the migration process is ongoing.
+If you are adding a new entry/functionality, please add it to
+`torch/ao/quantization/quant_type.py`, while adding an import statement
+here.
+"""
 
-# Quantization type (dynamic quantization, static quantization).
-# Should match the c++ enum in quantization_type.h
-class QuantType(enum.IntEnum):
-    DYNAMIC = 0
-    STATIC = 1
-    QAT = 2
-    WEIGHT_ONLY = 3
-
-
-def quant_type_to_str(quant_type):
-    m = {
-        QuantType.STATIC: "static",
-        QuantType.DYNAMIC: "dynamic",
-        QuantType.QAT: "qat",
-        QuantType.WEIGHT_ONLY: "weight_only",
-    }
-    return m[quant_type]
+from torch.ao.quantization.quant_type import QuantType
+from torch.ao.quantization.quant_type import quant_type_to_str
diff --git a/torch/quantization/utils.py b/torch/quantization/utils.py
index 9c5198b..12ca5d9 100644
@@ -4,7 +4,7 @@ Utils shared by different modes of quantization (eager/graph)
 import warnings
 import functools
 import torch
-from .quant_type import QuantType, quant_type_to_str
+from torch.ao.quantization.quant_type import QuantType, quant_type_to_str
 from typing import Tuple, Any
 
 def get_combined_dict(default_dict, additional_dict):
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 33e758c..315dab4 100644
@@ -12,10 +12,11 @@ from torch.nn.intrinsic import _FusedModule
 import torch.distributed as dist
 
 from torch.testing._internal.common_utils import TestCase
+from torch.ao.quantization import QuantType
 from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \
     default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
     propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
-    get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, QuantType, quantize
+    get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, QConfigDynamic, quantize
 from torch.quantization.quantization_mappings import (
     get_default_dynamic_quant_module_mappings,
     get_default_qconfig_propagation_list,