Add `skipIfTBB` decorator (#64942)
authorNikita Shulga <nshulga@fb.com>
Tue, 14 Sep 2021 00:10:30 +0000 (17:10 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 00:11:51 +0000 (17:11 -0700)
Summary:
And replace two existing usages in the codebase with it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64942

Reviewed By: jbschlosser

Differential Revision: D30906382

Pulled By: malfet

fbshipit-source-id: e7f20f53aff734b0379eded361255543dab4fa4b

test/test_tensor_creation_ops.py
torch/testing/_internal/common_methods_invocations.py
torch/testing/_internal/common_utils.py

index 9db8645..2404f02 100644 (file)
@@ -1,7 +1,6 @@
 import torch
 import numpy as np
 
-import os
 import sys
 import math
 import warnings
@@ -12,8 +11,8 @@ import random
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC,
-    IS_WINDOWS)
+    torch_to_numpy_dtype_dict, skipIfTBB, slowTest,
+    TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
     onlyCPU, largeTensorTest, precisionOverride, dtypes,
@@ -1191,8 +1190,7 @@ class TestTensorCreation(TestCase):
         self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
 
     # TODO: update to work on CUDA, too
-    @unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""),
-                     "This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
+    @skipIfTBB("This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
     @onlyCPU
     def test_trilu_indices(self, device):
         for test_args in tri_tests_args:
index 5949eb8..a2b9fea 100644 (file)
@@ -7,7 +7,6 @@ import operator
 import random
 import numbers
 import unittest
-import os
 
 import torch
 import numpy as np
@@ -35,7 +34,7 @@ from torch.testing._internal.common_utils import \
      random_fullrank_matrix_distinct_singular_value,
      TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
      torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
-     GRADCHECK_NONDET_TOL,)
+     GRADCHECK_NONDET_TOL, skipIfTBB)
 import torch.testing._internal.opinfo_helper as opinfo_helper
 
 from setuptools import distutils
@@ -7479,7 +7478,7 @@ op_db: List[OpInfo] = [
                    toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
                    'TestCommon', 'test_reference_testing'
                ),
-               unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""), "This test makes TBB Sad"),
+               skipIfTBB(),
            ],
            sample_inputs_func=sample_inputs_layer_norm,),
     OpInfo('nn.functional.pad',
index 11364c3..922d5c8 100644 (file)
@@ -456,6 +456,9 @@ TEST_SKIP_CUDA_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_SKIP_CUDA_MEM_LEAK_CHECK
 # Disables tests for when on Github Actions
 ON_GHA = os.getenv('GITHUB_ACTIONS', '0') == '1'
 
+# True if CI is running TBB-enabled Pytorch
+IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "")
+
 # Dict of NumPy dtype -> torch dtype (when the correspondence exists)
 numpy_to_torch_dtype_dict = {
     np.bool_      : torch.bool,
@@ -692,6 +695,18 @@ def skipIfOnGHA(fn):
     return wrapper
 
 
+def skipIfTBB(message="This test makes TBB sad"):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            if IS_TBB:
+                raise unittest.SkipTest(message)
+            else:
+                fn(*args, **kwargs)
+        return wrapper
+    return dec_fn
+
+
 def slowTest(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):