[OpInfo] Add expected_failure kwarg to SkipInfo (#62963)
authorHeitor Schueroff <heitorschueroff@fb.com>
Mon, 16 Aug 2021 01:06:41 +0000 (18:06 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 01:09:20 +0000 (18:09 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62963

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30327199

Pulled By: heitorschueroff

fbshipit-source-id: 45231eca11d1697a4449d79849fb17264d128a6b

torch/testing/_internal/common_methods_invocations.py

index 381a568..1c22ffa 100644 (file)
@@ -21,7 +21,7 @@ from torch.testing import \
      integral_types_and, all_types, double_types)
 from .._core import _dispatch_dtypes
 from torch.testing._internal.common_device_type import \
-    (onlyOnCPUAndCUDA, skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
+    (expectedFailure, onlyOnCPUAndCUDA, skipIf, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfNoCusolver,
      skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride, toleranceOverride, tol)
 from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater
 from torch.testing._internal.common_utils import \
@@ -74,11 +74,22 @@ class SkipInfo(DecorateInfo):
        an operator. Any test that matches all provided arguments will be skipped.
        The skip will only be checked if the active_if argument is True."""
 
-    def __init__(self, cls_name=None, test_name=None, *,
-                 device_type=None, dtypes=None, active_if=True):
-        super().__init__(decorators=skipIf(True, "Skipped!"), cls_name=cls_name,
-                         test_name=test_name, device_type=device_type, dtypes=dtypes,
-                         active_if=active_if)
+    def __init__(
+            self, cls_name=None, test_name=None, *, device_type=None, dtypes=None, active_if=True,
+            expected_failure=False):
+        """
+        Args:
+            cls_name: the name of the test class to skip
+            test_name: the name of the test within the test class to skip
+            device_type: the devices for which to skip the tests
+            dtypes: the dtypes for which to skip the tests
+            active_if: whether tests matching the above arguments should be skipped
+            expected_failure: whether to assert that skipped tests fail
+        """
+        decorator = expectedFailure(device_type) if expected_failure else skipIf(True, "Skipped!")
+        super().__init__(decorators=decorator, cls_name=cls_name, test_name=test_name,
+                         device_type=device_type, dtypes=dtypes, active_if=active_if)
+
 
 class SampleInput(object):
     """Represents sample inputs to a function."""