Generic test parametrization functionality (#60753)
authorJoel Schlosser <jbschlosser@fb.com>
Wed, 15 Sep 2021 02:51:32 +0000 (19:51 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 02:52:59 +0000 (19:52 -0700)
commitb7ec7d760d1120683fef1a0ad6c03ddf4b8b7d0c
tree5f4cb7e41a44762ed7d63bd3e6eb13191818ca67
parent6ab97fbc287ea89ce81abcb24829b502aeb309cc
Generic test parametrization functionality (#60753)

Summary:
This PR plays around with implementation & usage of a `parametrize` decorator for test parametrization similar to `pytest.mark.parametrize`, based on previous work introducing a `_TestParametrizer` class. It works with the internal `DeviceTest` hierarchy & composes with `dtype`, `skip*`, and other decorators. Basic usage is demonstrated in `test/test_blah.py`:

```python
import unittest
from itertools import product
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, deviceCountAtLeast, ops)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    TestCase, run_tests, parametrize, instantiate_parametrized_tests, subtest)

class TestBlah(TestCase):
    parametrize("x", range(5))
    def test_default_names(self, x):
        print('Passed in:', x)

    # Use default names but add an expected failure.
    parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]),
                       *range(1, 5)])
    def test_default_names_expected_failure(self, x):
        if x == 0:
            raise RuntimeError('Boom')
        print('Passed in:', x)

    parametrize("bias", [False, True], name_fn=lambda b: 'bias' if b else 'no_bias')
    def test_custom_names(self, bias):
        print('Passed in:', bias)

    parametrize("bias", [subtest(True, name='bias'),
                          subtest(False, name='no_bias')])
    def test_custom_names_alternate(self, bias):
        print('Passed in:', bias)

    parametrize("x,y", [(1, 2), (1, 3), (1, 4)])
    def test_two_things_default_names(self, x, y):
        print('Passed in:', x, y)

    parametrize("x", [1, 2, 3])
    parametrize("y", [4, 5, 6])
    def test_two_things_composition(self, x, y):
        print('Passed in:', x, y)

    parametrize("x", [subtest(0, decorators=[unittest.expectedFailure]),
                       *range(1, 3)])
    parametrize("y", [4, 5, subtest(6, decorators=[unittest.expectedFailure])])
    def test_two_things_composition_expected_failure(self, x, y):
        if x == 0 or y == 6:
            raise RuntimeError('Boom')
        print('Passed in:', x, y)

    parametrize("x", [1, 2])
    parametrize("y", [3, 4])
    parametrize("z", [5, 6])
    def test_three_things_composition(self, x, y, z):
        print('Passed in:', x, y, z)

    parametrize("x", [1, 2], name_fn=str)
    parametrize("y", [3, 4], name_fn=str)
    parametrize("z", [5, 6], name_fn=str)
    def test_three_things_composition_custom_names(self, x, y, z):
        print('Passed in:', x, y, z)

    parametrize("x,y", product(range(2), range(3)))
    def test_two_things_product(self, x, y):
        print('Passed in:', x, y)

    parametrize("x,y", [subtest((1, 2), name='double'),
                         subtest((1, 3), name='triple'),
                         subtest((1, 4), name='quadruple')])
    def test_two_things_custom_names(self, x, y):
        print('Passed in:', x, y)

    parametrize("x,y", [(1, 2), (1, 3), (1, 4)], name_fn=lambda x, y: '{}_{}'.format(x, y))
    def test_two_things_custom_names_alternate(self, x, y):
        print('Passed in:', x, y)

class TestDeviceBlah(TestCase):
    parametrize("x", range(10))
    def test_default_names(self, device, x):
        print('Passed in:', device, x)

    parametrize("x,y", [(1, 2), (3, 4), (5, 6)])
    def test_two_things(self, device, x, y):
        print('Passed in:', device, x, y)

    deviceCountAtLeast(1)
    def test_multiple_devices(self, devices):
        print('Passed in:', devices)

    ops(op_db)
    parametrize("flag", [False, True], lambda f: 'flag_enabled' if f else 'flag_disabled')
    def test_op_parametrized(self, device, dtype, op, flag):
        print('Passed in:', device, dtype, op, flag)

instantiate_parametrized_tests(TestBlah)
instantiate_device_type_tests(TestDeviceBlah, globals())

if __name__ == '__main__':
    run_tests()
```

Generated tests:
```
TestBlah.test_custom_names_alternate_bias
TestBlah.test_custom_names_alternate_no_bias
TestBlah.test_custom_names_bias
TestBlah.test_custom_names_no_bias
TestBlah.test_default_names_expected_failure_x_0
TestBlah.test_default_names_expected_failure_x_1
TestBlah.test_default_names_expected_failure_x_2
TestBlah.test_default_names_expected_failure_x_3
TestBlah.test_default_names_expected_failure_x_4
TestBlah.test_default_names_x_0
TestBlah.test_default_names_x_1
TestBlah.test_default_names_x_2
TestBlah.test_default_names_x_3
TestBlah.test_default_names_x_4
TestBlah.test_three_things_composition_custom_names_1_3_5
TestBlah.test_three_things_composition_custom_names_1_3_6
TestBlah.test_three_things_composition_custom_names_1_4_5
TestBlah.test_three_things_composition_custom_names_1_4_6
TestBlah.test_three_things_composition_custom_names_2_3_5
TestBlah.test_three_things_composition_custom_names_2_3_6
TestBlah.test_three_things_composition_custom_names_2_4_5
TestBlah.test_three_things_composition_custom_names_2_4_6
TestBlah.test_three_things_composition_x_1_y_3_z_5
TestBlah.test_three_things_composition_x_1_y_3_z_6
TestBlah.test_three_things_composition_x_1_y_4_z_5
TestBlah.test_three_things_composition_x_1_y_4_z_6
TestBlah.test_three_things_composition_x_2_y_3_z_5
TestBlah.test_three_things_composition_x_2_y_3_z_6
TestBlah.test_three_things_composition_x_2_y_4_z_5
TestBlah.test_three_things_composition_x_2_y_4_z_6
TestBlah.test_two_things_composition_expected_failure_x_0_y_4
TestBlah.test_two_things_composition_expected_failure_x_0_y_5
TestBlah.test_two_things_composition_expected_failure_x_0_y_6
TestBlah.test_two_things_composition_expected_failure_x_1_y_4
TestBlah.test_two_things_composition_expected_failure_x_1_y_5
TestBlah.test_two_things_composition_expected_failure_x_1_y_6
TestBlah.test_two_things_composition_expected_failure_x_2_y_4
TestBlah.test_two_things_composition_expected_failure_x_2_y_5
TestBlah.test_two_things_composition_expected_failure_x_2_y_6
TestBlah.test_two_things_composition_x_1_y_4
TestBlah.test_two_things_composition_x_1_y_5
TestBlah.test_two_things_composition_x_1_y_6
TestBlah.test_two_things_composition_x_2_y_4
TestBlah.test_two_things_composition_x_2_y_5
TestBlah.test_two_things_composition_x_2_y_6
TestBlah.test_two_things_composition_x_3_y_4
TestBlah.test_two_things_composition_x_3_y_5
TestBlah.test_two_things_composition_x_3_y_6
TestBlah.test_two_things_custom_names_alternate_1_2
TestBlah.test_two_things_custom_names_alternate_1_3
TestBlah.test_two_things_custom_names_alternate_1_4
TestBlah.test_two_things_custom_names_double
TestBlah.test_two_things_custom_names_quadruple
TestBlah.test_two_things_custom_names_triple
TestBlah.test_two_things_default_names_x_1_y_2
TestBlah.test_two_things_default_names_x_1_y_3
TestBlah.test_two_things_default_names_x_1_y_4
TestBlah.test_two_things_product_x_0_y_0
TestBlah.test_two_things_product_x_0_y_1
TestBlah.test_two_things_product_x_0_y_2
TestBlah.test_two_things_product_x_1_y_0
TestBlah.test_two_things_product_x_1_y_1
TestBlah.test_two_things_product_x_1_y_2
TestDeviceBlahCPU.test_default_names_x_0_cpu
TestDeviceBlahCPU.test_default_names_x_1_cpu
TestDeviceBlahCPU.test_default_names_x_2_cpu
TestDeviceBlahCPU.test_default_names_x_3_cpu
TestDeviceBlahCPU.test_default_names_x_4_cpu
TestDeviceBlahCPU.test_default_names_x_5_cpu
TestDeviceBlahCPU.test_default_names_x_6_cpu
TestDeviceBlahCPU.test_default_names_x_7_cpu
TestDeviceBlahCPU.test_default_names_x_8_cpu
TestDeviceBlahCPU.test_default_names_x_9_cpu
TestDeviceBlahCPU.test_multiple_devices_cpu
TestDeviceBlahCPU.test_op_parametrized_<opname>_<variant>_cpu_uint8_flag_enabled_cpu
TestDeviceBlahCPU.test_two_things_x_1_y_2_cpu
TestDeviceBlahCPU.test_two_things_x_3_y_4_cpu
TestDeviceBlahCPU.test_two_things_x_5_y_6_cpu
TestDeviceBlahMETA.test_default_names_x_0_meta
TestDeviceBlahMETA.test_default_names_x_1_meta
TestDeviceBlahMETA.test_default_names_x_2_meta
TestDeviceBlahMETA.test_default_names_x_3_meta
TestDeviceBlahMETA.test_default_names_x_4_meta
TestDeviceBlahMETA.test_default_names_x_5_meta
TestDeviceBlahMETA.test_default_names_x_6_meta
TestDeviceBlahMETA.test_default_names_x_7_meta
TestDeviceBlahMETA.test_default_names_x_8_meta
TestDeviceBlahMETA.test_default_names_x_9_meta
TestDeviceBlahMETA.test_multiple_devices_meta
TestDeviceBlahMETA.test_op_parametrized_<opname>_<variant>_meta_uint8_flag_enabled_meta
TestDeviceBlahMETA.test_two_things_x_1_y_2_meta
TestDeviceBlahMETA.test_two_things_x_3_y_4_meta
TestDeviceBlahMETA.test_two_things_x_5_y_6_meta
```

Caveats:
* `parametrize` decorators cannot be "stacked" yet; each one overwrites the previous. This will change to either:
  * Allow stacking of multiple decorators
  * Error out with a nice error message if multiple decorators are specified

The PR introduces `instantiate_parametrized_tests()` in addition to `instantiate_device_type_tests()`. The former should be used for non-device-specific tests, and the latter should be used for device-specific tests, as usual. Both of these support the `parametrize` decorator. Only the latter supports the `ops` decorator (no change here- this was already the case).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60753

Reviewed By: saketh-are

Differential Revision: D30606615

Pulled By: jbschlosser

fbshipit-source-id: a34f36d643f68a6e221f419d9bb3e1ae1d84dd65
test/test_testing.py
torch/testing/_internal/common_device_type.py
torch/testing/_internal/common_methods_invocations.py
torch/testing/_internal/common_modules.py
torch/testing/_internal/common_utils.py