From c4073af61d4e530f8627dced870f42526d93dcaf Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 13 Sep 2021 17:10:30 -0700 Subject: [PATCH] Add `skipIfTBB` decorator (#64942) Summary: And replace two existing usages in the codebase with it Pull Request resolved: https://github.com/pytorch/pytorch/pull/64942 Reviewed By: jbschlosser Differential Revision: D30906382 Pulled By: malfet fbshipit-source-id: e7f20f53aff734b0379eded361255543dab4fa4b --- test/test_tensor_creation_ops.py | 8 +++----- torch/testing/_internal/common_methods_invocations.py | 5 ++--- torch/testing/_internal/common_utils.py | 15 +++++++++++++++ 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 9db8645..2404f02 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -1,7 +1,6 @@ import torch import numpy as np -import os import sys import math import warnings @@ -12,8 +11,8 @@ import random from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, - torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC, - IS_WINDOWS) + torch_to_numpy_dtype_dict, skipIfTBB, slowTest, + TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA, onlyCPU, largeTensorTest, precisionOverride, dtypes, @@ -1191,8 +1190,7 @@ class TestTensorCreation(TestCase): self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d)) # TODO: update to work on CUDA, too - @unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""), - "This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571") + @skipIfTBB("This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571") @onlyCPU def test_trilu_indices(self, device): for test_args in tri_tests_args: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5949eb8..a2b9fea 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -7,7 +7,6 @@ import operator import random import numbers import unittest -import os import torch import numpy as np @@ -35,7 +34,7 @@ from torch.testing._internal.common_utils import \ random_fullrank_matrix_distinct_singular_value, TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, TEST_WITH_ASAN, - GRADCHECK_NONDET_TOL,) + GRADCHECK_NONDET_TOL, skipIfTBB) import torch.testing._internal.opinfo_helper as opinfo_helper from setuptools import distutils @@ -7479,7 +7478,7 @@ op_db: List[OpInfo] = [ toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}), 'TestCommon', 'test_reference_testing' ), - unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""), "This test makes TBB Sad"), + skipIfTBB(), ], sample_inputs_func=sample_inputs_layer_norm,), OpInfo('nn.functional.pad', diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 11364c3..922d5c8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -456,6 +456,9 @@ TEST_SKIP_CUDA_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_SKIP_CUDA_MEM_LEAK_CHECK # Disables tests for when on Github Actions ON_GHA = os.getenv('GITHUB_ACTIONS', '0') == '1' +# True if CI is running TBB-enabled Pytorch +IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "") + # Dict of NumPy dtype -> torch dtype (when the correspondence exists) numpy_to_torch_dtype_dict = { np.bool_ : torch.bool, @@ -692,6 +695,18 @@ def skipIfOnGHA(fn): return wrapper +def skipIfTBB(message="This test makes TBB sad"): + def dec_fn(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if IS_TBB: + raise unittest.SkipTest(message) + else: + fn(*args, **kwargs) + return wrapper + return dec_fn + + def slowTest(fn): @wraps(fn) def wrapper(*args, **kwargs): -- 2.7.4