From b7ec7d760d1120683fef1a0ad6c03ddf4b8b7d0c Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Tue, 14 Sep 2021 19:51:32 -0700 Subject: [PATCH] Generic test parametrization functionality (#60753) Summary: This PR plays around with implementation & usage of a `parametrize` decorator for test parametrization similar to `pytest.mark.parametrize`, based on previous work introducing a `_TestParametrizer` class. It works with the internal `DeviceTest` hierarchy & composes with `dtype`, `skip*`, and other decorators. Basic usage is demonstrated in `test/test_blah.py`: ```python import unittest from itertools import product from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, deviceCountAtLeast, ops) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( TestCase, run_tests, parametrize, instantiate_parametrized_tests, subtest) class TestBlah(TestCase): parametrize("x", range(5)) def test_default_names(self, x): print('Passed in:', x) # Use default names but add an expected failure. parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]), *range(1, 5)]) def test_default_names_expected_failure(self, x): if x == 0: raise RuntimeError('Boom') print('Passed in:', x) parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias') def test_custom_names(self, bias): print('Passed in:', bias) parametrize("bias", [subtest(True, name='bias'), subtest(False, name='no_bias')]) def test_custom_names_alternate(self, bias): print('Passed in:', bias) parametrize("x,y", [(1, 2), (1, 3), (1, 4)]) def test_two_things_default_names(self, x, y): print('Passed in:', x, y) parametrize("x", [1, 2, 3]) parametrize("y", [4, 5, 6]) def test_two_things_composition(self, x, y): print('Passed in:', x, y) parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]), *range(1, 3)]) parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) def test_two_things_composition_expected_failure(self, x, y): if x == 0 or y == 6: raise RuntimeError('Boom') print('Passed in:', x, y) parametrize("x", [1, 2]) parametrize("y", [3, 4]) parametrize("z", [5, 6]) def test_three_things_composition(self, x, y, z): print('Passed in:', x, y, z) parametrize("x", [1, 2], name_fn=str) parametrize("y", [3, 4], name_fn=str) parametrize("z", [5, 6], name_fn=str) def test_three_things_composition_custom_names(self, x, y, z): print('Passed in:', x, y, z) parametrize("x,y", product(range(2), range(3))) def test_two_things_product(self, x, y): print('Passed in:', x, y) parametrize("x,y", [subtest((1, 2), name='double'), subtest((1, 3), name='triple'), subtest((1, 4), name='quadruple')]) def test_two_things_custom_names(self, x, y): print('Passed in:', x, y) parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}_{}'.format(x, y)) def test_two_things_custom_names_alternate(self, x, y): print('Passed in:', x, y) class TestDeviceBlah(TestCase): parametrize("x", range(10)) def test_default_names(self, device, x): print('Passed in:', device, x) parametrize("x,y", [(1, 2), (3, 4), (5, 6)]) def test_two_things(self, device, x, y): print('Passed in:', device, x, y) deviceCountAtLeast(1) def test_multiple_devices(self, devices): print('Passed in:', devices) ops(op_db) parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled') def test_op_parametrized(self, device, dtype, op, flag): print('Passed in:', device, dtype, op, flag) instantiate_parametrized_tests(TestBlah) instantiate_device_type_tests(TestDeviceBlah, globals()) if __name__ == '__main__': run_tests() ``` Generated tests: ``` TestBlah.test_custom_names_alternate_bias TestBlah.test_custom_names_alternate_no_bias TestBlah.test_custom_names_bias TestBlah.test_custom_names_no_bias TestBlah.test_default_names_expected_failure_x_0 TestBlah.test_default_names_expected_failure_x_1 TestBlah.test_default_names_expected_failure_x_2 TestBlah.test_default_names_expected_failure_x_3 TestBlah.test_default_names_expected_failure_x_4 TestBlah.test_default_names_x_0 TestBlah.test_default_names_x_1 TestBlah.test_default_names_x_2 TestBlah.test_default_names_x_3 TestBlah.test_default_names_x_4 TestBlah.test_three_things_composition_custom_names_1_3_5 TestBlah.test_three_things_composition_custom_names_1_3_6 TestBlah.test_three_things_composition_custom_names_1_4_5 TestBlah.test_three_things_composition_custom_names_1_4_6 TestBlah.test_three_things_composition_custom_names_2_3_5 TestBlah.test_three_things_composition_custom_names_2_3_6 TestBlah.test_three_things_composition_custom_names_2_4_5 TestBlah.test_three_things_composition_custom_names_2_4_6 TestBlah.test_three_things_composition_x_1_y_3_z_5 TestBlah.test_three_things_composition_x_1_y_3_z_6 TestBlah.test_three_things_composition_x_1_y_4_z_5 TestBlah.test_three_things_composition_x_1_y_4_z_6 TestBlah.test_three_things_composition_x_2_y_3_z_5 TestBlah.test_three_things_composition_x_2_y_3_z_6 TestBlah.test_three_things_composition_x_2_y_4_z_5 TestBlah.test_three_things_composition_x_2_y_4_z_6 TestBlah.test_two_things_composition_expected_failure_x_0_y_4 TestBlah.test_two_things_composition_expected_failure_x_0_y_5 TestBlah.test_two_things_composition_expected_failure_x_0_y_6 TestBlah.test_two_things_composition_expected_failure_x_1_y_4 TestBlah.test_two_things_composition_expected_failure_x_1_y_5 TestBlah.test_two_things_composition_expected_failure_x_1_y_6 TestBlah.test_two_things_composition_expected_failure_x_2_y_4 TestBlah.test_two_things_composition_expected_failure_x_2_y_5 TestBlah.test_two_things_composition_expected_failure_x_2_y_6 TestBlah.test_two_things_composition_x_1_y_4 TestBlah.test_two_things_composition_x_1_y_5 TestBlah.test_two_things_composition_x_1_y_6 TestBlah.test_two_things_composition_x_2_y_4 TestBlah.test_two_things_composition_x_2_y_5 TestBlah.test_two_things_composition_x_2_y_6 TestBlah.test_two_things_composition_x_3_y_4 TestBlah.test_two_things_composition_x_3_y_5 TestBlah.test_two_things_composition_x_3_y_6 TestBlah.test_two_things_custom_names_alternate_1_2 TestBlah.test_two_things_custom_names_alternate_1_3 TestBlah.test_two_things_custom_names_alternate_1_4 TestBlah.test_two_things_custom_names_double TestBlah.test_two_things_custom_names_quadruple TestBlah.test_two_things_custom_names_triple TestBlah.test_two_things_default_names_x_1_y_2 TestBlah.test_two_things_default_names_x_1_y_3 TestBlah.test_two_things_default_names_x_1_y_4 TestBlah.test_two_things_product_x_0_y_0 TestBlah.test_two_things_product_x_0_y_1 TestBlah.test_two_things_product_x_0_y_2 TestBlah.test_two_things_product_x_1_y_0 TestBlah.test_two_things_product_x_1_y_1 TestBlah.test_two_things_product_x_1_y_2 TestDeviceBlahCPU.test_default_names_x_0_cpu TestDeviceBlahCPU.test_default_names_x_1_cpu TestDeviceBlahCPU.test_default_names_x_2_cpu TestDeviceBlahCPU.test_default_names_x_3_cpu TestDeviceBlahCPU.test_default_names_x_4_cpu TestDeviceBlahCPU.test_default_names_x_5_cpu TestDeviceBlahCPU.test_default_names_x_6_cpu TestDeviceBlahCPU.test_default_names_x_7_cpu TestDeviceBlahCPU.test_default_names_x_8_cpu TestDeviceBlahCPU.test_default_names_x_9_cpu TestDeviceBlahCPU.test_multiple_devices_cpu TestDeviceBlahCPU.test_op_parametrized___cpu_uint8_flag_enabled_cpu TestDeviceBlahCPU.test_two_things_x_1_y_2_cpu TestDeviceBlahCPU.test_two_things_x_3_y_4_cpu TestDeviceBlahCPU.test_two_things_x_5_y_6_cpu TestDeviceBlahMETA.test_default_names_x_0_meta TestDeviceBlahMETA.test_default_names_x_1_meta TestDeviceBlahMETA.test_default_names_x_2_meta TestDeviceBlahMETA.test_default_names_x_3_meta TestDeviceBlahMETA.test_default_names_x_4_meta TestDeviceBlahMETA.test_default_names_x_5_meta TestDeviceBlahMETA.test_default_names_x_6_meta TestDeviceBlahMETA.test_default_names_x_7_meta TestDeviceBlahMETA.test_default_names_x_8_meta TestDeviceBlahMETA.test_default_names_x_9_meta TestDeviceBlahMETA.test_multiple_devices_meta TestDeviceBlahMETA.test_op_parametrized___meta_uint8_flag_enabled_meta TestDeviceBlahMETA.test_two_things_x_1_y_2_meta TestDeviceBlahMETA.test_two_things_x_3_y_4_meta TestDeviceBlahMETA.test_two_things_x_5_y_6_meta ``` Caveats: * `parametrize` decorators cannot be "stacked" yet; each one overwrites the previous. This will change to either: * Allow stacking of multiple decorators * Error out with a nice error message if multiple decorators are specified The PR introduces `instantiate_parametrized_tests()` in addition to `instantiate_device_type_tests()`. The former should be used for non-device-specific tests, and the latter should be used for device-specific tests, as usual. Both of these support the `parametrize` decorator. Only the latter supports the `ops` decorator (no change here- this was already the case). Pull Request resolved: https://github.com/pytorch/pytorch/pull/60753 Reviewed By: saketh-are Differential Revision: D30606615 Pulled By: jbschlosser fbshipit-source-id: a34f36d643f68a6e221f419d9bb3e1ae1d84dd65 --- test/test_testing.py | 312 ++++++++++++++++++++- torch/testing/_internal/common_device_type.py | 94 ++----- .../_internal/common_methods_invocations.py | 6 + torch/testing/_internal/common_modules.py | 12 +- torch/testing/_internal/common_utils.py | 276 +++++++++++++++++- 5 files changed, 627 insertions(+), 73 deletions(-) diff --git a/test/test_testing.py b/test/test_testing.py index e45977f..d777587 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -12,11 +12,12 @@ import torch from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest) + (IS_FBCODE, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, skipIfRocm, slowTest, + parametrize, subtest, instantiate_parametrized_tests, dtype_name) from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, get_device_type_test_bases, instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA, - deviceCountAtLeast) + deviceCountAtLeast, ops) from torch.testing._internal.common_methods_invocations import op_db import torch.testing._internal.opinfo_helper as opinfo_helper from torch.testing._internal.common_dtype import get_all_dtypes @@ -1425,5 +1426,312 @@ class TestAssertCloseQuantized(TestCase): fn() +def _get_test_names_for_test_class(test_cls): + """ Convenience function to get all test names for a given test class. """ + test_names = ['{}.{}'.format(test_cls.__name__, key) for key in test_cls.__dict__ + if key.startswith('test_')] + return sorted(test_names) + + +class TestTestParametrization(TestCase): + def test_default_names(self): + + class TestParametrized(TestCase): + @parametrize("x", range(5)) + def test_default_names(self, x): + pass + + @parametrize("x,y", [(1, 2), (2, 3), (3, 4)]) + def test_two_things_default_names(self, x, y): + pass + + instantiate_parametrized_tests(TestParametrized) + + expected_test_names = [ + 'TestParametrized.test_default_names_x_0', + 'TestParametrized.test_default_names_x_1', + 'TestParametrized.test_default_names_x_2', + 'TestParametrized.test_default_names_x_3', + 'TestParametrized.test_default_names_x_4', + 'TestParametrized.test_two_things_default_names_x_1_y_2', + 'TestParametrized.test_two_things_default_names_x_2_y_3', + 'TestParametrized.test_two_things_default_names_x_3_y_4', + ] + test_names = _get_test_names_for_test_class(TestParametrized) + self.assertEqual(expected_test_names, test_names) + + def test_name_fn(self): + + class TestParametrized(TestCase): + @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias') + def test_custom_names(self, bias): + pass + + @parametrize("x", [1, 2], name_fn=str) + @parametrize("y", [3, 4], name_fn=str) + @parametrize("z", [5, 6], name_fn=str) + def test_three_things_composition_custom_names(self, x, y, z): + pass + + @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}__{}'.format(x, y)) + def test_two_things_custom_names_alternate(self, x, y): + pass + + instantiate_parametrized_tests(TestParametrized) + + expected_test_names = [ + 'TestParametrized.test_custom_names_bias', + 'TestParametrized.test_custom_names_no_bias', + 'TestParametrized.test_three_things_composition_custom_names_1_3_5', + 'TestParametrized.test_three_things_composition_custom_names_1_3_6', + 'TestParametrized.test_three_things_composition_custom_names_1_4_5', + 'TestParametrized.test_three_things_composition_custom_names_1_4_6', + 'TestParametrized.test_three_things_composition_custom_names_2_3_5', + 'TestParametrized.test_three_things_composition_custom_names_2_3_6', + 'TestParametrized.test_three_things_composition_custom_names_2_4_5', + 'TestParametrized.test_three_things_composition_custom_names_2_4_6', + 'TestParametrized.test_two_things_custom_names_alternate_1__2', + 'TestParametrized.test_two_things_custom_names_alternate_1__3', + 'TestParametrized.test_two_things_custom_names_alternate_1__4', + ] + test_names = _get_test_names_for_test_class(TestParametrized) + self.assertEqual(expected_test_names, test_names) + + def test_subtest_names(self): + + class TestParametrized(TestCase): + @parametrize("bias", [subtest(True, name='bias'), + subtest(False, name='no_bias')]) + def test_custom_names(self, bias): + pass + + @parametrize("x,y", [subtest((1, 2), name='double'), + subtest((1, 3), name='triple'), + subtest((1, 4), name='quadruple')]) + def test_two_things_custom_names(self, x, y): + pass + + instantiate_parametrized_tests(TestParametrized) + + expected_test_names = [ + 'TestParametrized.test_custom_names_bias', + 'TestParametrized.test_custom_names_no_bias', + 'TestParametrized.test_two_things_custom_names_double', + 'TestParametrized.test_two_things_custom_names_quadruple', + 'TestParametrized.test_two_things_custom_names_triple', + ] + test_names = _get_test_names_for_test_class(TestParametrized) + self.assertEqual(expected_test_names, test_names) + + @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3]) + def test_subtest_expected_failure(self, x): + if x == 2: + raise RuntimeError('Boom') + + @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3]) + @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) + def test_two_things_subtest_expected_failure(self, x, y): + if x == 1 or y == 6: + raise RuntimeError('Boom') + + +class TestTestParametrizationDeviceType(TestCase): + def test_unparametrized_names(self, device): + # This test exists to protect against regressions in device / dtype test naming + # due to parametrization logic. + + device = self.device_type + + class TestParametrized(TestCase): + def test_device_specific(self, device): + pass + + @dtypes(torch.float32, torch.float64) + def test_device_dtype_specific(self, device, dtype): + pass + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()['TestParametrized{}'.format(device.upper())] + expected_test_names = [name.format(device_cls.__name__, device) for name in ( + '{}.test_device_dtype_specific_{}_float32', + '{}.test_device_dtype_specific_{}_float64', + '{}.test_device_specific_{}') + ] + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(expected_test_names, test_names) + + def test_default_names(self, device): + device = self.device_type + + class TestParametrized(TestCase): + @parametrize("x", range(5)) + def test_default_names(self, device, x): + pass + + @parametrize("x,y", [(1, 2), (2, 3), (3, 4)]) + def test_two_things_default_names(self, device, x, y): + pass + + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()['TestParametrized{}'.format(device.upper())] + expected_test_names = [name.format(device_cls.__name__, device) for name in ( + '{}.test_default_names_x_0_{}', + '{}.test_default_names_x_1_{}', + '{}.test_default_names_x_2_{}', + '{}.test_default_names_x_3_{}', + '{}.test_default_names_x_4_{}', + '{}.test_two_things_default_names_x_1_y_2_{}', + '{}.test_two_things_default_names_x_2_y_3_{}', + '{}.test_two_things_default_names_x_3_y_4_{}') + ] + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(expected_test_names, test_names) + + # Note: Currently, the device string is inserted into the name multiple times. + # To fix this, the responsibility for adding the device string can be pushed outside + # into instantiate_device_type_tests(). This will result in the device string always being + # at the end of the test name, which is different from now for @ops tests. This possibly + # breaking change will be made in a future PR. + @unittest.expectedFailure + def test_name_fn(self, device): + device = self.device_type + + class TestParametrized(TestCase): + @parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias') + def test_custom_names(self, device, bias): + pass + + @parametrize("x", [1, 2], name_fn=str) + @parametrize("y", [3, 4], name_fn=str) + @parametrize("z", [5, 6], name_fn=str) + def test_three_things_composition_custom_names(self, device, x, y, z): + pass + + @parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}__{}'.format(x, y)) + def test_two_things_custom_names_alternate(self, device, x, y): + pass + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()['TestParametrized{}'.format(device.upper())] + expected_test_names = [name.format(device_cls.__name__, device) for name in ( + '{}.test_custom_names_bias_{}', + '{}.test_custom_names_no_bias_{}', + '{}.test_three_things_composition_custom_names_1_3_5_{}', + '{}.test_three_things_composition_custom_names_1_3_6_{}', + '{}.test_three_things_composition_custom_names_1_4_5_{}', + '{}.test_three_things_composition_custom_names_1_4_6_{}', + '{}.test_three_things_composition_custom_names_2_3_5_{}', + '{}.test_three_things_composition_custom_names_2_3_6_{}', + '{}.test_three_things_composition_custom_names_2_4_5_{}', + '{}.test_three_things_composition_custom_names_2_4_6_{}', + '{}.test_two_things_custom_names_alternate_1__2_{}', + '{}.test_two_things_custom_names_alternate_1__3_{}', + '{}.test_two_things_custom_names_alternate_1__4_{}') + ] + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(expected_test_names, test_names) + + def test_subtest_names(self, device): + device = self.device_type + + class TestParametrized(TestCase): + @parametrize("bias", [subtest(True, name='bias'), + subtest(False, name='no_bias')]) + def test_custom_names(self, device, bias): + pass + + @parametrize("x,y", [subtest((1, 2), name='double'), + subtest((1, 3), name='triple'), + subtest((1, 4), name='quadruple')]) + def test_two_things_custom_names(self, device, x, y): + pass + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()['TestParametrized{}'.format(device.upper())] + expected_test_names = [name.format(device_cls.__name__, device) for name in ( + '{}.test_custom_names_bias_{}', + '{}.test_custom_names_no_bias_{}', + '{}.test_two_things_custom_names_double_{}', + '{}.test_two_things_custom_names_quadruple_{}', + '{}.test_two_things_custom_names_triple_{}') + ] + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(expected_test_names, test_names) + + # Note: Currently, the device string is inserted into the name multiple times. + # To fix this, the responsibility for adding the device string can be pushed outside + # into instantiate_device_type_tests(). This will result in the device string always being + # at the end of the test name, which is different from now for @ops tests. This possibly + # breaking change will be made in a future PR. + @unittest.expectedFailure + def test_ops_composition_names(self, device): + device = self.device_type + + class TestParametrized(TestCase): + @ops(op_db) + @parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled') + def test_op_parametrized(self, device, dtype, op, flag): + pass + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()['TestParametrized{}'.format(device.upper())] + expected_test_names = [] + for op in op_db: + for dtype in op.default_test_dtypes(device): + for flag_part in ('_flag_disabled_', '_flag_enabled_'): + op_name = '{}{}'.format(op.name, '_' + op.variant_test_name if op.variant_test_name else '') + part1 = '{}.test_op_parametrized_{}'.format(device_cls.__name__, op_name) + expected_test_names.append(part1 + '_' + dtype_name(dtype) + flag_part + device) + + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(sorted(expected_test_names), sorted(test_names)) + + def test_dtypes_composition_names(self, device): + # Test checks that @parametrize and @dtypes compose as expected. + + device = self.device_type + + class TestParametrized(TestCase): + @dtypes(torch.float32, torch.float64) + @parametrize("x", range(3)) + def test_parametrized(self, x, dtype): + pass + + instantiate_device_type_tests(TestParametrized, locals(), only_for=device) + + device_cls = locals()['TestParametrized{}'.format(device.upper())] + expected_test_names = [name.format(device_cls.__name__, device) for name in ( + '{}.test_parametrized_x_0_{}_float32', + '{}.test_parametrized_x_0_{}_float64', + '{}.test_parametrized_x_1_{}_float32', + '{}.test_parametrized_x_1_{}_float64', + '{}.test_parametrized_x_2_{}_float32', + '{}.test_parametrized_x_2_{}_float64') + ] + test_names = _get_test_names_for_test_class(device_cls) + self.assertEqual(sorted(expected_test_names), sorted(test_names)) + + @parametrize("x", [1, subtest(2, decorators=[unittest.expectedFailure]), 3]) + def test_subtest_expected_failure(self, device, x): + if x == 2: + raise RuntimeError('Boom') + + @parametrize("x", [subtest(1, decorators=[unittest.expectedFailure]), 2, 3]) + @parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])]) + def test_two_things_subtest_expected_failure(self, device, x, y): + if x == 1 or y == 6: + raise RuntimeError('Boom') + + +instantiate_parametrized_tests(TestTestParametrization) +instantiate_device_type_tests(TestTestParametrizationDeviceType, globals()) + + if __name__ == '__main__': run_tests() diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 971b3a6..ee3c7ff 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -13,7 +13,7 @@ import torch from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \ IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, DeterministicGuard, TEST_SKIP_NOARCH, \ - TEST_WITH_MIOPEN_SUGGEST_NHWC + _TestParametrizer, dtype_name, TEST_WITH_MIOPEN_SUGGEST_NHWC from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing._internal.common_dtype import get_all_dtypes @@ -252,19 +252,14 @@ except ImportError: # then inherit from it for your generic test. -def _dtype_name(dtype): - """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """ - return str(dtype).split('.')[1] - - def _dtype_test_suffix(dtypes): """ Returns the test suffix for a dtype, sequence of dtypes, or None. """ if isinstance(dtypes, list) or isinstance(dtypes, tuple): if len(dtypes) == 0: return '' - return '_' + '_'.join((_dtype_name(d) for d in dtypes)) + return '_' + '_'.join((dtype_name(d) for d in dtypes)) elif dtypes: - return '_{}'.format(_dtype_name(dtypes)) + return '_{}'.format(dtype_name(dtypes)) else: return '' @@ -382,22 +377,32 @@ class DeviceTypeTestBase(TestCase): return result - assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name) - setattr(cls, test_name, instantiated_test) + assert not hasattr(cls, name), "Redefinition of test {0}".format(name) + setattr(cls, name, instantiated_test) # Handles tests that need parametrization (e.g. those that run across a set of # ops / modules using the @ops or @modules decorators). - if hasattr(test, 'parametrize_fn'): - for (test, test_name, param_kwargs) in test.parametrize_fn(test, generic_cls, cls): - instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs) - else: - dtypes = cls._get_dtypes(test) - dtypes = tuple(dtypes) if dtypes is not None else (None,) - for dtype in dtypes: - param_kwargs = {} - _update_param_kwargs(param_kwargs, 'dtype', dtype) - test_name = '{}_{}{}'.format(name, cls.device_type, _dtype_test_suffix(dtype)) - instantiate_test_helper(cls=cls, name=test_name, test=test, param_kwargs=param_kwargs) + + def default_parametrize_fn(test, generic_cls, cls): + # By default, parametrize only over device. + test_suffix = cls.device_type + yield (test, test_suffix, {}) + + parametrize_fn = test.parametrize_fn if hasattr(test, 'parametrize_fn') else default_parametrize_fn + for (test, test_suffix, param_kwargs) in parametrize_fn(test, generic_cls, cls): + if hasattr(test, 'handles_dtypes') and test.handles_dtypes: + full_name = '{}_{}'.format(name, test_suffix) + instantiate_test_helper(cls=cls, name=full_name, test=test, param_kwargs=param_kwargs) + else: + # The parametrize_fn doesn't handle dtypes internally; handle them here instead by generating + # a test per dtype. + dtypes = cls._get_dtypes(test) + dtypes = tuple(dtypes) if dtypes is not None else (None,) + for dtype in dtypes: + all_param_kwargs = dict(param_kwargs) + _update_param_kwargs(all_param_kwargs, 'dtype', dtype) + full_name = '{}_{}{}'.format(name, test_suffix, _dtype_test_suffix(dtype)) + instantiate_test_helper(cls=cls, name=full_name, test=test, param_kwargs=all_param_kwargs) def run(self, result=None): super().run(result=result) @@ -634,43 +639,6 @@ class OpDTypes(Enum): none = 5 # Instantiate no dtype variants (no dtype kwarg needed) -class _TestParametrizer(object): - """ - Decorator class for parametrizing a test function, yielding a set of new tests spawned - from the original generic test, each specialized for a specific set of test inputs. For - example, parametrizing a test across the set of ops will result in a test function per op. - - The decision of how to parametrize / what to parametrize over is intended to be implemented - by each derived class. - - In the details, the decorator adds a 'parametrize_fn' property to the test function that is called - during device-specific test instantiation performed in instantiate_device_type_tests(). Because of this, - there is no need to parametrize over device type, as that is already handled separately. - """ - def _parametrize_test(self, test, generic_cls, device_cls): - """ - Parametrizes the given test function across whatever dimension is specified by the derived class. - Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all - ops, all modules, or all ops + their associated dtypes. - - Args: - test (fn): Test function to parametrize over; must support least a device arg - generic_cls (class): Generic test class object containing tests (e.g. TestFoo) - device_cls (class): Device-specialized test class object (e.g. TestFooCPU) - - Returns: - Generator object returning 3-tuples of: - test (fn): Parametrized test function; must support a device arg and args for any params - test_name (str): Parametrized name of the test (e.g. test_bar_opname_int64) - param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64}) - """ - raise NotImplementedError - - def __call__(self, fn): - fn.parametrize_fn = self._parametrize_test - return fn - - # Decorator that defines the OpInfos a test template should be instantiated for. # # Example usage: @@ -712,6 +680,7 @@ class _TestParametrizer(object): class ops(_TestParametrizer): def __init__(self, op_list, *, dtypes: OpDTypes = OpDTypes.basic, allowed_dtypes: Optional[Sequence[torch.dtype]] = None): + super().__init__(handles_dtypes=True) self.op_list = op_list self.opinfo_dtypes = dtypes self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None @@ -745,11 +714,10 @@ class ops(_TestParametrizer): for dtype in dtypes: # Construct the test name. - test_name = '{}_{}{}_{}{}'.format(test.__name__, - op.name.replace('.', '_'), - '_' + op.variant_test_name if op.variant_test_name else '', - device_cls.device_type, - _dtype_test_suffix(dtype)) + test_name = '{}{}_{}{}'.format(op.name.replace('.', '_'), + '_' + op.variant_test_name if op.variant_test_name else '', + device_cls.device_type, + _dtype_test_suffix(dtype)) # Construct parameter kwargs to pass to the test. param_kwargs = {'op': op} diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5dd1cb2..5aa8b67 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -758,6 +758,12 @@ class OpInfo(object): return (supported if self._default_test_dtypes is None else supported.intersection(self._default_test_dtypes)) + @property + def formatted_name(self): + """Returns a formatted full name for this OpInfo that can be used in test names.""" + variant = '_' + self.variant_test_name if self.variant_test_name else '' + return '{}{}'.format(self.name.replace('.', '_'), variant) + def _generate_reduction_inputs(device, dtype, requires_grad): """Generates input tensors for testing reduction operators""" diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index b1cbbb3..a7133b7 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -49,6 +49,7 @@ class modules(_TestParametrizer): """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """ def __init__(self, module_info_list): + super().__init__(handles_dtypes=True) self.module_info_list = module_info_list def _parametrize_test(self, test, generic_cls, device_cls): @@ -56,10 +57,9 @@ class modules(_TestParametrizer): # TODO: Factor some of this out since it's similar to OpInfo. for dtype in floating_types(): # Construct the test name. - test_name = '{}_{}_{}{}'.format(test.__name__, - module_info.name.replace('.', '_'), - device_cls.device_type, - _dtype_test_suffix(dtype)) + test_name = '{}_{}{}'.format(module_info.name.replace('.', '_'), + device_cls.device_type, + _dtype_test_suffix(dtype)) # Construct parameter kwargs to pass to the test. param_kwargs = {'module_info': module_info} @@ -153,6 +153,10 @@ class ModuleInfo(object): def name(self): return formatted_module_name(self.module_cls) + @property + def formatted_name(self): + return self.name.replace('.', '_') + def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 922d5c8..b32c3a1 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -76,6 +76,267 @@ SLOW_TESTS_FILE = '.pytorch-slow-tests.json' slow_tests_dict: Optional[Dict[str, Any]] = None disabled_tests_dict: Optional[Dict[str, Any]] = None + +class _TestParametrizer(object): + """ + Decorator class for parametrizing a test function, yielding a set of new tests spawned + from the original generic test, each specialized for a specific set of test inputs. For + example, parametrizing a test across the set of ops will result in a test function per op. + + The decision of how to parametrize / what to parametrize over is intended to be implemented + by each derived class. + + In the details, the decorator adds a 'parametrize_fn' property to the test function that is called + during device-specific test instantiation performed in instantiate_device_type_tests(). Because of this, + there is no need to parametrize over device type, as that is already handled separately. + + If the decorator is applied to a test function that already has a 'parametrize_fn' property, a new + composite 'parametrize_fn' will be created that generates tests with the product of the parameters + generated by the old and new parametrize_fns. This allows for convenient composability of decorators. + + Args: + handles_dtypes (bool): If True, indicates that it is the responsibility of the decorator to handle + dtypes internally. This allows for more flexibility when needed (e.g. for op-specific dtype handling). + Default: True + """ + def __init__(self, handles_dtypes=True): + self.handles_dtypes = handles_dtypes + + def _parametrize_test(self, test, generic_cls, device_cls): + """ + Parametrizes the given test function across whatever dimension is specified by the derived class. + Tests can be parametrized over any arbitrary dimension or combination of dimensions, such as all + ops, all modules, or all ops + their associated dtypes. + + Args: + test (fn): Test function to parametrize over + generic_cls (class): Generic test class object containing tests (e.g. TestFoo) + device_cls (class): Device-specialized test class object (e.g. TestFooCPU); set to None + if the tests are not part of a device-specific set + + Returns: + Generator object returning 3-tuples of: + test (fn): Parametrized test function; must support a device arg and args for any params + test_name (str): Parametrized suffix for the test (e.g. opname_int64); will be appended to + the base name of the test + param_kwargs (dict): Param kwargs to pass to the test (e.g. {'op': 'add', 'dtype': torch.int64}) + """ + raise NotImplementedError + + def __call__(self, fn): + if hasattr(fn, 'parametrize_fn'): + # Do composition with the product of args. + old_parametrize_fn = fn.parametrize_fn + new_parametrize_fn = self._parametrize_test + + def composite_fn(test, generic_cls, device_cls, + old_parametrize_fn=old_parametrize_fn, + new_parametrize_fn=new_parametrize_fn): + old_tests = [(test, test_name, param_kwargs) for (test, test_name, param_kwargs) in + old_parametrize_fn(test, generic_cls, device_cls)] + for (old_test, old_test_name, old_param_kwargs) in old_tests: + for (new_test, new_test_name, new_param_kwargs) in \ + new_parametrize_fn(old_test, generic_cls, device_cls): + full_param_kwargs = {**old_param_kwargs, **new_param_kwargs} + yield (new_test, '{}_{}'.format(new_test_name, old_test_name), full_param_kwargs) + + fn.parametrize_fn = composite_fn + old_handles_dtypes = fn.handles_dtypes if hasattr(fn, 'handles_dtypes') else False + if self.handles_dtypes and old_handles_dtypes: + raise RuntimeError('Cannot compose multiple parametrization decorators that handle dtypes; ' + 'their dtype handling conflicts') + fn.handles_dtypes = self.handles_dtypes or old_handles_dtypes + else: + fn.parametrize_fn = self._parametrize_test + fn.handles_dtypes = self.handles_dtypes + return fn + + +def instantiate_parametrized_tests(generic_cls): + """ + Instantiates tests that have been decorated with a parametrize_fn. This is generally performed by a + decorator subclass of _TestParametrizer. The generic test will be replaced on the test class by + parametrized tests with specialized names. + + Args: + generic_cls (class): Generic test class object containing tests (e.g. TestFoo) + """ + for attr_name in tuple(dir(generic_cls)): + class_attr = getattr(generic_cls, attr_name) + if not hasattr(class_attr, 'parametrize_fn'): + continue + + if hasattr(class_attr, 'handles_dtypes') and class_attr.handles_dtypes: + raise RuntimeError('instantiate_parametrized_tests() should not be used with decorators ' + 'that handle dtypes internally (e.g. @ops, @modules, etc.). Use ' + 'instantiate_device_type_tests() with these instead.') + + # Remove the generic test from the test class. + delattr(generic_cls, attr_name) + + # Add parametrized tests to the test class. + def instantiate_test_helper(cls, name, test, param_kwargs): + @wraps(test) + def instantiated_test(self, param_kwargs=param_kwargs): + test(self, **param_kwargs) + + assert not hasattr(generic_cls, name), "Redefinition of test {0}".format(name) + setattr(generic_cls, name, instantiated_test) + + for (test, test_suffix, param_kwargs) in class_attr.parametrize_fn( + class_attr, generic_cls=generic_cls, device_cls=None): + full_name = '{}_{}'.format(test.__name__, test_suffix) + instantiate_test_helper(cls=generic_cls, name=full_name, test=test, param_kwargs=param_kwargs) + + +class subtest(object): + """ + Explicit subtest case for use with test parametrization. + Allows for explicit naming of individual subtest cases as well as applying + decorators to the parametrized test. + + Args: + arg_values (iterable): Iterable of arg values (e.g. range(10)) or + tuples of arg values (e.g. [(1, 2), (3, 4)]). + name (str): Optional name to use for the test. + decorators (iterable): Iterable of decorators to apply to the generated test. + """ + __slots__ = ['arg_values', 'name', 'decorators'] + + def __init__(self, arg_values, name=None, decorators=None): + self.arg_values = arg_values + self.name = name + self.decorators = decorators if decorators else [] + + +class parametrize(_TestParametrizer): + """ + Decorator for applying generic test parametrizations. + + The interface for this decorator is modeled after `@pytest.mark.parametrize`. + Basic usage between this decorator and pytest's is identical. The first argument + should be a string containing comma-separated names of parameters for the test, and + the second argument should be an iterable returning values or tuples of values for + the case of multiple parameters. + + Beyond this basic usage, the decorator provides some additional functionality that + pytest does not. + + 1. Parametrized tests end up as generated test functions on unittest test classes. + Since this differs from how pytest works, this decorator takes on the additional + responsibility of naming these test functions. The default test names consists of + the test's base name followed by each parameter name + value (e.g. "test_bar_x_1_y_foo"), + but custom names can be defined using `name_fn` or the `subtest` structure (see below). + + 2. The decorator specially handles parameter values of type `subtest`, which allows for + more fine-grained control over both test naming and test execution. In particular, it can + be used to tag subtests with explicit test names or apply arbitrary decorators (see examples + below). + + Examples:: + + @parametrize("x", range(5)) + def test_foo(self, x): + ... + + @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')]) + def test_bar(self, x, y): + ... + + @parametrize("x,y", [(1, 'foo'), (2, 'bar'), (3, 'baz')], + name_fn=lambda x, y: '{}_{}'.format(x, y)) + def test_bar_custom_names(self, x, y): + ... + + @parametrize("x, y", [subtest((1, 2), name='double'), + subtest((1, 3), name='triple', decorators=[unittest.expectedFailure]), + subtest((1, 4), name='quadruple')]) + def test_baz(self, x, y): + ... + + Args: + arg_str (str): String of arg names separate by commas (e.g. "x,y"). + arg_values (iterable): Iterable of arg values (e.g. range(10)) or + tuples of arg values (e.g. [(1, 2), (3, 4)]). + name_fn (callable): Optional function that takes in parameters and returns subtest name. + """ + def __init__(self, arg_str, arg_values, name_fn=None): + super().__init__(handles_dtypes=False) + self.arg_names = arg_str.split(',') + self.arg_values = arg_values + self.name_fn = name_fn + + def _formatted_str_repr(self, name, value): + """ Returns a string representation for the given arg that is suitable for use in test function names. """ + if isinstance(value, torch.dtype): + return dtype_name(value) + elif isinstance(value, torch.device): + return str(value) + # Can't use isinstance as it would cause a circular import + elif value.__class__.__name__ == 'OpInfo' or value.__class__.__name__ == 'ModuleInfo': + return value.formatted_name + else: + # Include name and value separated by underscore. + return '{}_{}'.format(name, str(value).replace('.', '_')) + + def _default_subtest_name(self, values): + return '_'.join([self._formatted_str_repr(a, v) for a, v in zip(self.arg_names, values)]) + + def _get_subtest_name(self, values, explicit_name=None): + if explicit_name: + subtest_name = explicit_name + elif self.name_fn: + subtest_name = self.name_fn(*values) + else: + subtest_name = self._default_subtest_name(values) + return subtest_name + + def _parametrize_test(self, test, generic_cls, device_cls): + if len(self.arg_names) == 0: + # No additional parameters needed for the test. + test_name = device_cls.device_type if device_cls else '' + yield (test, test_name, {}) + else: + # Each "values" item is expected to be either: + # * A tuple of values with one for each arg. For a single arg, a single item is expected. + # * A subtest instance with arg_values matching the previous. + for values in self.arg_values: + maybe_name = None + if isinstance(values, subtest): + sub = values + values = sub.arg_values + maybe_name = sub.name + + # Apply decorators. + @wraps(test) + def test_wrapper(*args, **kwargs): + return test(*args, **kwargs) + + for decorator in sub.decorators: + test_wrapper = decorator(test_wrapper) + + gen_test = test_wrapper + else: + gen_test = test + + values = list(values) if len(self.arg_names) > 1 else [values] + if len(values) != len(self.arg_names): + raise RuntimeError('Expected # values == # arg names, but got: {} ' + 'values and {} names for test "{}"'.format( + len(values), len(self.arg_names), test.__name__)) + + param_kwargs = { + name: value for name, value in zip(self.arg_names, values) + } + + subtest_name = self._get_subtest_name(values, explicit_name=maybe_name) + test_name = '{}{}'.format(subtest_name, '_' + device_cls.device_type if device_cls else '') + if '.' in test_name: + raise RuntimeError('Test name cannot contain periods, but got: {}'.format(test_name)) + + yield (gen_test, test_name, param_kwargs) + + class ProfilingMode(Enum): LEGACY = 1 SIMPLE = 2 @@ -271,6 +532,12 @@ def discover_test_cases_recursively(suite_or_case): def get_test_names(test_cases): return ['.'.join(case.id().split('.')[-2:]) for case in test_cases] +def _print_test_names(): + suite = unittest.TestLoader().loadTestsFromModule(__main__) + test_cases = discover_test_cases_recursively(suite) + for name in get_test_names(test_cases): + print(name) + def chunk_list(lst, nchunks): return [lst[i::nchunks] for i in range(nchunks)] @@ -300,10 +567,7 @@ def run_tests(argv=UNITTEST_ARGS): print(f'[WARNING] disabled test file provided but not found: {IMPORT_DISABLED_TESTS}') # Determine the test launch mechanism if TEST_DISCOVER: - suite = unittest.TestLoader().loadTestsFromModule(__main__) - test_cases = discover_test_cases_recursively(suite) - for name in get_test_names(test_cases): - print(name) + _print_test_names() elif TEST_IN_SUBPROCESS: suite = unittest.TestLoader().loadTestsFromModule(__main__) test_cases = discover_test_cases_recursively(suite) @@ -2585,3 +2849,7 @@ def sandcastle_skip_if(condition, reason): return wrapper return decorator + +def dtype_name(dtype): + """ Returns the pretty name of the dtype (e.g. torch.int64 -> int64). """ + return str(dtype).split('.')[1] -- 2.7.4