[quant][refactor] Change the structure of the ao migration tests (#64912)
authorZafar Takhirov <zaf@fb.com>
Wed, 15 Sep 2021 20:11:58 +0000 (13:11 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 20:15:43 +0000 (13:15 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64912

The test naming was confusing and ambiguous. The file was changed to reflect the framework that is being migrated ("quantization" instead of "quantize"). Also, the common testing class was extracted out
ghstack-source-id: 138157450

Test Plan: `buck test mode/dev //caffe2/test:quantization -- TestAOMigrationQuantization`

Reviewed By: vkuzo

Differential Revision: D30898214

fbshipit-source-id: 017f95995271d35bcdf6ff6a1b3974b837543e84

test/quantization/ao_migration/common.py [new file with mode: 0644]
test/quantization/ao_migration/test_quantization.py [moved from test/quantization/ao_migration/test_quantize.py with 62% similarity]
test/test_quantization.py

diff --git a/test/quantization/ao_migration/common.py b/test/quantization/ao_migration/common.py
new file mode 100644 (file)
index 0000000..1723bb5
--- /dev/null
@@ -0,0 +1,33 @@
+from torch.testing._internal.common_utils import TestCase
+
+import importlib
+from typing import List
+
+class AOMigrationTestCase(TestCase):
+    def _test_package_import(self, package_name: str):
+        r"""Tests the module import by making sure that all the internals match
+        (except the dunder methods)."""
+        old_module = importlib.import_module(f'torch.quantization.{package_name}')
+        new_module = importlib.import_module(f'torch.ao.quantization.{package_name}')
+        old_module_dir = set(dir(old_module))
+        new_module_dir = set(dir(new_module))
+        # Remove magic modules from checking in subsets
+        for el in list(old_module_dir):
+            if el[:2] == '__' and el[-2:] == '__':
+                old_module_dir.remove(el)
+        assert (old_module_dir <= new_module_dir), \
+            f"Importing {old_module} vs. {new_module} does not match: " \
+            f"{old_module_dir - new_module_dir}"
+
+    def _test_function_import(self, package_name: str, function_list: List[str]):
+        r"""Tests individual function list import by comparing the functions
+        and their hashes."""
+        old_location = importlib.import_module(f'torch.quantization.{package_name}')
+        new_location = importlib.import_module(f'torch.ao.quantization.{package_name}')
+        for fn_name in function_list:
+            old_function = getattr(old_location, fn_name)
+            new_function = getattr(new_location, fn_name)
+            assert old_function == new_function, f"Functions don't match: {fn_name}"
+            assert hash(old_function) == hash(new_function), \
+                f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
+                f"{new_function}({hash(new_function)})"
@@ -1,40 +1,10 @@
-from torch.testing._internal.common_utils import TestCase
+from .common import AOMigrationTestCase
 
-import importlib
-from typing import List
 
-
-class AOMigrationTestCase(TestCase):
-    def _test_package_import(self, package_name: str):
-        r"""Tests the module import by making sure that all the internals match
-        (except the dunder methods)."""
-        old_module = importlib.import_module(f'torch.quantization.{package_name}')
-        new_module = importlib.import_module(f'torch.ao.quantization.{package_name}')
-        old_module_dir = set(dir(old_module))
-        new_module_dir = set(dir(new_module))
-        # Remove magic modules from checking in subsets
-        for el in list(old_module_dir):
-            if el[:2] == '__' and el[-2:] == '__':
-                old_module_dir.remove(el)
-        assert (old_module_dir <= new_module_dir), \
-            f"Importing {old_module} vs. {new_module} does not match: " \
-            f"{old_module_dir - new_module_dir}"
-
-    def _test_function_import(self, package_name: str, function_list: List[str]):
-        r"""Tests individual function list import by comparing the functions
-        and their hashes."""
-        old_location = importlib.import_module(f'torch.quantization.{package_name}')
-        new_location = importlib.import_module(f'torch.ao.quantization.{package_name}')
-        for fn_name in function_list:
-            old_function = getattr(old_location, fn_name)
-            new_function = getattr(new_location, fn_name)
-            assert old_function == new_function, f"Functions don't match: {fn_name}"
-            assert hash(old_function) == hash(new_function), \
-                f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
-                f"{new_function}({hash(new_function)})"
-
-
-class TestAOMigrationQuantizePy(AOMigrationTestCase):
+class TestAOMigrationQuantization(AOMigrationTestCase):
+    r"""Modules and functions related to the
+    `torch/quantization` migration to `torch/ao/quantization`.
+    """
     def test_package_import_quantize(self):
         self._test_package_import('quantize')
 
index 7fd4e50..cddfaec 100644 (file)
@@ -101,7 +101,7 @@ from quantization.jit.test_fusion_passes import TestFusionPasses  # noqa: F401
 from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized  # noqa: F401
 
 # AO Migration tests
-from quantization.ao_migration.test_quantize import TestAOMigrationQuantizePy  # noqa: F401
+from quantization.ao_migration.test_quantization import TestAOMigrationQuantization  # noqa: F401
 
 
 if __name__ == '__main__':