import torch
import numpy as np
-import os
import sys
import math
import warnings
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
- torch_to_numpy_dtype_dict, slowTest, TEST_SCIPY, IS_MACOS, IS_PPC,
- IS_WINDOWS)
+ torch_to_numpy_dtype_dict, skipIfTBB, slowTest,
+ TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, deviceCountAtLeast, onlyOnCPUAndCUDA,
onlyCPU, largeTensorTest, precisionOverride, dtypes,
self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), device=device, dtype=torch.float32, out=d))
# TODO: update to work on CUDA, too
- @unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""),
- "This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
+ @skipIfTBB("This test makes TBB sad, see https://github.com/pytorch/pytorch/issues/64571")
@onlyCPU
def test_trilu_indices(self, device):
for test_args in tri_tests_args:
import random
import numbers
import unittest
-import os
import torch
import numpy as np
random_fullrank_matrix_distinct_singular_value,
TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY,
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
- GRADCHECK_NONDET_TOL,)
+ GRADCHECK_NONDET_TOL, skipIfTBB)
import torch.testing._internal.opinfo_helper as opinfo_helper
from setuptools import distutils
toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-03)}),
'TestCommon', 'test_reference_testing'
),
- unittest.skipIf("tbb" in os.getenv("BUILD_ENVIRONMENT", ""), "This test makes TBB Sad"),
+ skipIfTBB(),
],
sample_inputs_func=sample_inputs_layer_norm,),
OpInfo('nn.functional.pad',
# Disables tests for when on Github Actions
ON_GHA = os.getenv('GITHUB_ACTIONS', '0') == '1'
+# True if CI is running TBB-enabled Pytorch
+IS_TBB = "tbb" in os.getenv("BUILD_ENVIRONMENT", "")
+
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
numpy_to_torch_dtype_dict = {
np.bool_ : torch.bool,
return wrapper
+def skipIfTBB(message="This test makes TBB sad"):
+ def dec_fn(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ if IS_TBB:
+ raise unittest.SkipTest(message)
+ else:
+ fn(*args, **kwargs)
+ return wrapper
+ return dec_fn
+
+
def slowTest(fn):
@wraps(fn)
def wrapper(*args, **kwargs):