From: Heitor Schueroff
Date: Mon, 16 Aug 2021 01:06:41 +0000 (-0700)
Subject: Small refactor for OpInfo decorators (#62713)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~1008
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8987726cc643044e1d554d5bac818cef0596306f;p=platform%2Fupstream%2Fpytorch.git

Small refactor for OpInfo decorators (#62713)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62713

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30327200

Pulled By: heitorschueroff

fbshipit-source-id: 1899293990c8c0a66da88646714b38f1aae9179d
---

diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index da80905..8ec6e71 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -761,25 +761,12 @@ class ops(_TestParametrizer):
                 # Test-specific decorators are applied to the original test,
                 # however.
                 try:
-                    active_decorators = []
-                    if op.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
-                        active_decorators.append(skipIf(True, "Skipped!"))
-
-                    if op.decorators is not None:
-                        for decorator in op.decorators:
-                            # Can't use isinstance as it would cause a circular import
-                            if decorator.__class__.__name__ == 'DecorateInfo':
-                                if decorator.is_active(generic_cls.__name__, test.__name__,
-                                                       device_cls.device_type, dtype):
-                                    active_decorators += decorator.decorators
-                            else:
-                                active_decorators.append(decorator)
-
                     @wraps(test)
                     def test_wrapper(*args, **kwargs):
                         return test(*args, **kwargs)
 
-                    for decorator in active_decorators:
+                    for decorator in op.get_decorators(
+                            generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
                         test_wrapper = decorator(test_wrapper)
 
                     yield (test_wrapper, test_name, param_kwargs)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 14fd5d1..381a568 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -458,7 +458,7 @@ class OpInfo(object):
                  # modifying tests and a pointer to the op's sample inputs function
                  # this function lets the OpInfo generate valid inputs
                  skips=tuple(),  # information about which tests to skip
-                 decorators=None,  # decorators to apply to generated tests
+                 decorators=tuple(),  # decorators to apply to generated tests
                  sample_inputs_func=None,  # function to generate sample inputs
 
                  # the following metadata relates to dtype support and is tested for correctness in test_ops.py
@@ -585,8 +585,7 @@ class OpInfo(object):
         self.supports_out = supports_out
         self.safe_casts_outputs = safe_casts_outputs
 
-        self.skips = skips
-        self.decorators = decorators
+        self.decorators = (*decorators, *skips)
 
         self.sample_inputs_func = sample_inputs_func
         self.assert_autodiffed = assert_autodiffed
@@ -700,10 +699,16 @@ class OpInfo(object):
 
         return samples
 
-    # Returns True if the test should be skipped and False otherwise
-    def should_skip(self, cls_name, test_name, device_type, dtype):
-        return any(si.is_active(cls_name, test_name, device_type, dtype)
-                   for si in self.skips)
+    def get_decorators(self, test_class, test_name, device, dtype):
+        '''Returns the decorators targeting the given test.'''
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(test_class, test_name, device, dtype):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
 
     def supported_dtypes(self, device_type):
         if device_type == 'cpu':
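
A minimal, self-contained sketch of the behavior this refactor settles on, using simplified stand-ins for DecorateInfo and OpInfo (the real classes live in torch/testing/_internal/common_methods_invocations.py and carry far more metadata). It only models how skips are folded into the decorators tuple and how get_decorators() selects the entries targeting a given test; names such as example_op and test_out are hypothetical.

    # Illustrative sketch only; not the real PyTorch classes.
    from unittest import skipIf

    class DecorateInfo:
        # Bundles one or more decorators with the test/device/dtype they target.
        def __init__(self, decorators, test_name=None, device_type=None, dtypes=None):
            self.decorators = list(decorators) if isinstance(decorators, (list, tuple)) else [decorators]
            self.test_name = test_name
            self.device_type = device_type
            self.dtypes = dtypes

        def is_active(self, cls_name, test_name, device_type, dtype):
            return ((self.test_name is None or self.test_name == test_name)
                    and (self.device_type is None or self.device_type == device_type)
                    and (self.dtypes is None or dtype in self.dtypes))

    class OpInfo:
        # Only the decorator/skip merging introduced by this commit is modeled.
        def __init__(self, name, *, skips=tuple(), decorators=tuple()):
            self.name = name
            # Skips are just DecorateInfo entries wrapping a skip decorator.
            self.decorators = (*decorators, *skips)

        def get_decorators(self, test_class, test_name, device, dtype):
            result = []
            for decorator in self.decorators:
                if isinstance(decorator, DecorateInfo):
                    if decorator.is_active(test_class, test_name, device, dtype):
                        result.extend(decorator.decorators)
                else:
                    result.append(decorator)
            return result

    op = OpInfo('example_op',
                skips=(DecorateInfo(skipIf(True, "Skipped!"),
                                    test_name='test_out', device_type='cuda'),))

    def test_out():
        pass

    # The device-type parametrizer applies every matching decorator to the generated
    # test wrapper, as the new loop in common_device_type.py does.
    for decorator in op.get_decorators('TestCommon', 'test_out', 'cuda', 'float32'):
        test_out = decorator(test_out)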