From: Zafar Takhirov Date: Wed, 15 Sep 2021 20:11:58 +0000 (-0700) Subject: [quant][refactor] Change the structure of the ao migration tests (#64912) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~182 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=425f173f9d8cb41988ff03fff316ff8424bcd521;p=platform%2Fupstream%2Fpytorch.git [quant][refactor] Change the structure of the ao migration tests (#64912) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64912 The test naming was confusing and ambiguous. The file was changed to reflect the framework that is being migrated ("quantization" instead of "quantize"). Also, the common testing class was extracted out ghstack-source-id: 138157450 Test Plan: `buck test mode/dev //caffe2/test:quantization -- TestAOMigrationQuantization` Reviewed By: vkuzo Differential Revision: D30898214 fbshipit-source-id: 017f95995271d35bcdf6ff6a1b3974b837543e84 --- diff --git a/test/quantization/ao_migration/common.py b/test/quantization/ao_migration/common.py new file mode 100644 index 0000000..1723bb5 --- /dev/null +++ b/test/quantization/ao_migration/common.py @@ -0,0 +1,33 @@ +from torch.testing._internal.common_utils import TestCase + +import importlib +from typing import List + +class AOMigrationTestCase(TestCase): + def _test_package_import(self, package_name: str): + r"""Tests the module import by making sure that all the internals match + (except the dunder methods).""" + old_module = importlib.import_module(f'torch.quantization.{package_name}') + new_module = importlib.import_module(f'torch.ao.quantization.{package_name}') + old_module_dir = set(dir(old_module)) + new_module_dir = set(dir(new_module)) + # Remove magic modules from checking in subsets + for el in list(old_module_dir): + if el[:2] == '__' and el[-2:] == '__': + old_module_dir.remove(el) + assert (old_module_dir <= new_module_dir), \ + f"Importing {old_module} vs. {new_module} does not match: " \ + f"{old_module_dir - new_module_dir}" + + def _test_function_import(self, package_name: str, function_list: List[str]): + r"""Tests individual function list import by comparing the functions + and their hashes.""" + old_location = importlib.import_module(f'torch.quantization.{package_name}') + new_location = importlib.import_module(f'torch.ao.quantization.{package_name}') + for fn_name in function_list: + old_function = getattr(old_location, fn_name) + new_function = getattr(new_location, fn_name) + assert old_function == new_function, f"Functions don't match: {fn_name}" + assert hash(old_function) == hash(new_function), \ + f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \ + f"{new_function}({hash(new_function)})" diff --git a/test/quantization/ao_migration/test_quantize.py b/test/quantization/ao_migration/test_quantization.py similarity index 62% rename from test/quantization/ao_migration/test_quantize.py rename to test/quantization/ao_migration/test_quantization.py index d6e6109..72893d1 100644 --- a/test/quantization/ao_migration/test_quantize.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -1,40 +1,10 @@ -from torch.testing._internal.common_utils import TestCase +from .common import AOMigrationTestCase -import importlib -from typing import List - -class AOMigrationTestCase(TestCase): - def _test_package_import(self, package_name: str): - r"""Tests the module import by making sure that all the internals match - (except the dunder methods).""" - old_module = importlib.import_module(f'torch.quantization.{package_name}') - new_module = importlib.import_module(f'torch.ao.quantization.{package_name}') - old_module_dir = set(dir(old_module)) - new_module_dir = set(dir(new_module)) - # Remove magic modules from checking in subsets - for el in list(old_module_dir): - if el[:2] == '__' and el[-2:] == '__': - old_module_dir.remove(el) - assert (old_module_dir <= new_module_dir), \ - f"Importing {old_module} vs. {new_module} does not match: " \ - f"{old_module_dir - new_module_dir}" - - def _test_function_import(self, package_name: str, function_list: List[str]): - r"""Tests individual function list import by comparing the functions - and their hashes.""" - old_location = importlib.import_module(f'torch.quantization.{package_name}') - new_location = importlib.import_module(f'torch.ao.quantization.{package_name}') - for fn_name in function_list: - old_function = getattr(old_location, fn_name) - new_function = getattr(new_location, fn_name) - assert old_function == new_function, f"Functions don't match: {fn_name}" - assert hash(old_function) == hash(new_function), \ - f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \ - f"{new_function}({hash(new_function)})" - - -class TestAOMigrationQuantizePy(AOMigrationTestCase): +class TestAOMigrationQuantization(AOMigrationTestCase): + r"""Modules and functions related to the + `torch/quantization` migration to `torch/ao/quantization`. + """ def test_package_import_quantize(self): self._test_package_import('quantize') diff --git a/test/test_quantization.py b/test/test_quantization.py index 7fd4e50..cddfaec 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -101,7 +101,7 @@ from quantization.jit.test_fusion_passes import TestFusionPasses # noqa: F401 from quantization.jit.test_deprecated_jit_quant import TestDeprecatedJitQuantized # noqa: F401 # AO Migration tests -from quantization.ao_migration.test_quantize import TestAOMigrationQuantizePy # noqa: F401 +from quantization.ao_migration.test_quantization import TestAOMigrationQuantization # noqa: F401 if __name__ == '__main__':