# Test-specific decorators are applied to the original test,
# however.
try:
- active_decorators = []
- if op.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
- active_decorators.append(skipIf(True, "Skipped!"))
-
- if op.decorators is not None:
- for decorator in op.decorators:
- # Can't use isinstance as it would cause a circular import
- if decorator.__class__.__name__ == 'DecorateInfo':
- if decorator.is_active(generic_cls.__name__, test.__name__,
- device_cls.device_type, dtype):
- active_decorators += decorator.decorators
- else:
- active_decorators.append(decorator)
-
@wraps(test)
def test_wrapper(*args, **kwargs):
return test(*args, **kwargs)
- for decorator in active_decorators:
+ for decorator in op.get_decorators(
+ generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
test_wrapper = decorator(test_wrapper)
yield (test_wrapper, test_name, param_kwargs)
# modifying tests and a pointer to the op's sample inputs function
# this function lets the OpInfo generate valid inputs
skips=tuple(), # information about which tests to skip
- decorators=None, # decorators to apply to generated tests
+ decorators=tuple(), # decorators to apply to generated tests
sample_inputs_func=None, # function to generate sample inputs
# the following metadata relates to dtype support and is tested for correctness in test_ops.py
self.supports_out = supports_out
self.safe_casts_outputs = safe_casts_outputs
- self.skips = skips
- self.decorators = decorators
+ self.decorators = (*decorators, *skips)
self.sample_inputs_func = sample_inputs_func
self.assert_autodiffed = assert_autodiffed
return samples
- # Returns True if the test should be skipped and False otherwise
- def should_skip(self, cls_name, test_name, device_type, dtype):
- return any(si.is_active(cls_name, test_name, device_type, dtype)
- for si in self.skips)
+ def get_decorators(self, test_class, test_name, device, dtype):
+ '''Returns the decorators targeting the given test.'''
+ result = []
+ for decorator in self.decorators:
+ if isinstance(decorator, DecorateInfo):
+ if decorator.is_active(test_class, test_name, device, dtype):
+ result.extend(decorator.decorators)
+ else:
+ result.append(decorator)
+ return result
def supported_dtypes(self, device_type):
if device_type == 'cpu':