)
-class SkipInfo(DecorateInfo):
- """Describes which test, or type of tests, should be skipped when testing
- an operator. Any test that matches all provided arguments will be skipped.
- The skip will only be checked if the active_if argument is True."""
-
- def __init__(
- self, cls_name=None, test_name=None, *, device_type=None, dtypes=None, active_if=True,
- expected_failure=False):
- """
- Args:
- cls_name: the name of the test class to skip
- test_name: the name of the test within the test class to skip
- device_type: the devices for which to skip the tests
- dtypes: the dtypes for which to skip the tests
- active_if: whether tests matching the above arguments should be skipped
- expected_failure: whether to assert that skipped tests fail
- """
- decorator = unittest.expectedFailure if expected_failure else unittest.skip("Skipped!")
- super().__init__(decorators=decorator, cls_name=cls_name, test_name=test_name,
- device_type=device_type, dtypes=dtypes, active_if=active_if)
-
-
-
class SampleInput(object):
"""Represents sample inputs to a function."""
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
        # Gradcheck hangs for this function on complex dtypes, so it raises NotImplementedError for now
- SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
),
OpInfo('linalg.eigvalsh',
aten_name='linalg_eigvalsh',
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck hangs for this function
- SkipInfo('TestGradients', 'test_forward_mode_AD'),),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),),
),
OpInfo('linalg.householder_product',
aten_name='linalg_householder_product',
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck hangs for this function
- SkipInfo('TestGradients', 'test_forward_mode_AD'),),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD'),),
),
OpInfo('eig',
op=torch.eig,