Small refactor for OpInfo decorators (#62713)
authorHeitor Schueroff <heitorschueroff@fb.com>
Mon, 16 Aug 2021 01:06:41 +0000 (18:06 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 01:08:12 +0000 (18:08 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62713

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30327200

Pulled By: heitorschueroff

fbshipit-source-id: 1899293990c8c0a66da88646714b38f1aae9179d

torch/testing/_internal/common_device_type.py
torch/testing/_internal/common_methods_invocations.py

index da80905..8ec6e71 100644 (file)
@@ -761,25 +761,12 @@ class ops(_TestParametrizer):
                 #   Test-specific decorators are applied to the original test,
                 #   however.
                 try:
-                    active_decorators = []
-                    if op.should_skip(generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
-                        active_decorators.append(skipIf(True, "Skipped!"))
-
-                    if op.decorators is not None:
-                        for decorator in op.decorators:
-                            # Can't use isinstance as it would cause a circular import
-                            if decorator.__class__.__name__ == 'DecorateInfo':
-                                if decorator.is_active(generic_cls.__name__, test.__name__,
-                                                       device_cls.device_type, dtype):
-                                    active_decorators += decorator.decorators
-                            else:
-                                active_decorators.append(decorator)
-
                     @wraps(test)
                     def test_wrapper(*args, **kwargs):
                         return test(*args, **kwargs)
 
-                    for decorator in active_decorators:
+                    for decorator in op.get_decorators(
+                            generic_cls.__name__, test.__name__, device_cls.device_type, dtype):
                         test_wrapper = decorator(test_wrapper)
 
                     yield (test_wrapper, test_name, param_kwargs)
index 14fd5d1..381a568 100644 (file)
@@ -458,7 +458,7 @@ class OpInfo(object):
                  # modifying tests and a pointer to the op's sample inputs function
                  # this function lets the OpInfo generate valid inputs
                  skips=tuple(),  # information about which tests to skip
-                 decorators=None,  # decorators to apply to generated tests
+                 decorators=tuple(),  # decorators to apply to generated tests
                  sample_inputs_func=None,  # function to generate sample inputs
 
                  # the following metadata relates to dtype support and is tested for correctness in test_ops.py
@@ -585,8 +585,7 @@ class OpInfo(object):
         self.supports_out = supports_out
         self.safe_casts_outputs = safe_casts_outputs
 
-        self.skips = skips
-        self.decorators = decorators
+        self.decorators = (*decorators, *skips)
         self.sample_inputs_func = sample_inputs_func
 
         self.assert_autodiffed = assert_autodiffed
@@ -700,10 +699,16 @@ class OpInfo(object):
 
         return samples
 
-    # Returns True if the test should be skipped and False otherwise
-    def should_skip(self, cls_name, test_name, device_type, dtype):
-        return any(si.is_active(cls_name, test_name, device_type, dtype)
-                   for si in self.skips)
+    def get_decorators(self, test_class, test_name, device, dtype):
+        '''Returns the decorators targeting the given test.'''
+        result = []
+        for decorator in self.decorators:
+            if isinstance(decorator, DecorateInfo):
+                if decorator.is_active(test_class, test_name, device, dtype):
+                    result.extend(decorator.decorators)
+            else:
+                result.append(decorator)
+        return result
 
     def supported_dtypes(self, device_type):
         if device_type == 'cpu':