From eca87f729d071d12ccb31dd2c958a989d8ac17af Mon Sep 17 00:00:00 2001
From: Heitor Schueroff
Date: Fri, 27 Aug 2021 10:16:02 -0700
Subject: [PATCH] Added reference tests to ReductionOpInfo (#62900)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62900

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D30408815

Pulled By: heitorschueroff

fbshipit-source-id: 6a1f82ac281920ff7405a42f46ccd796e60af9d6
---
 aten/src/ATen/native/cpu/ReduceOpsKernel.cpp |  29 ++--
 test/test_reductions.py                      | 111 ++++++++++++-
 .../_internal/common_methods_invocations.py  | 179 +++++++++++++++------
 3 files changed, 258 insertions(+), 61 deletions(-)

diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 89d2fb2..01ed54e 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -163,24 +163,29 @@ static void std_var_kernel_impl(TensorIterator& iter, int64_t correction, bool t
 }
 
 static void prod_kernel_impl(TensorIterator& iter) {
-  // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+  // Workaround for the error: '*' in boolean context, suggest '&&' instead
+  // [-Werror=int-in-bool-context]
   if (iter.dtype() == ScalarType::Bool) {
     using scalar_t = bool;
     binary_kernel_reduce_vec(
-      iter,
-      [=](scalar_t a, scalar_t b) -> scalar_t { return a && b; },
-      [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a && b; },
-      // NOLINTNEXTLINE(bugprone-argument-comment)
-      /*identity=*/1);
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
-      binary_kernel_reduce_vec(
         iter,
-        [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
-        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a * b; },
+        [=](scalar_t a, scalar_t b)
+            __ubsan_ignore_undefined__ -> scalar_t { return a && b; },
+        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
+            __ubsan_ignore_undefined__ { return a && b; },
         // NOLINTNEXTLINE(bugprone-argument-comment)
         /*identity=*/1);
-    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+      binary_kernel_reduce_vec(
+          iter,
+          [=](scalar_t a, scalar_t b)
+              __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
+          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
+              __ubsan_ignore_undefined__ { return a * b; },
+          // NOLINTNEXTLINE(bugprone-argument-comment)
+          /*identity=*/1);
+    });
   }
 }
 
diff --git a/test/test_reductions.py b/test/test_reductions.py
index e716336..eed7f73 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -10,7 +10,7 @@ import warnings
 
 from torch._six import inf, nan
 from torch.testing import (
-    integral_types_and, floating_and_complex_types_and)
+    integral_types_and, floating_and_complex_types_and, get_all_dtypes)
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, skipIfNoSciPy, slowTest,
     torch_to_numpy_dtype_dict, IS_WINDOWS, make_tensor)
@@ -296,6 +296,115 @@ class TestReductions(TestCase):
             result = op(t, *args, dim=dim, **kwargs)
             self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
 
+    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
+        """Helper method to test noncontiguous input tensors."""
+        assert not t.is_contiguous()
+
+        t_contig = t.contiguous()
+        for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs):
+            kwargs.update(reduction_kwargs)
+            result = op(t, *args, **kwargs)
+            expected = op(t_contig, *args, **kwargs)
+            self.assertEqual(result, expected)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing along noncontiguous innermost dimension."""
+        t = make_tensor((10, 10), device, dtype)
+        self._test_noncontiguous(op, t[:, ::2], dim=1)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing along noncontiguous outermost dimension."""
+        t = make_tensor((10, 10), device, dtype)
+        self._test_noncontiguous(op, t[::2, :], dim=0)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing all dimensions of a noncontiguous tensor."""
+        t = make_tensor((5, 5, 5), device, dtype)
+        self._test_noncontiguous(op, t[::2, ::3, 1:-1:2])
+
+    @ops(reduction_ops)
+    def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing a transposed tensor."""
+        t = make_tensor((5, 5), device, dtype)
+        self._test_noncontiguous(op, t.T)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing a tensor with expanded singleton dimensions."""
+        t = make_tensor((2, 3), device, dtype)
+        self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1))
+
+    # NumPy does not support BFloat16 so we don't test that against reference
+    # implementations. We also don't compare dtypes or test for different
+    # keepdim because we already have other tests covering those.
+    # The test_reference_testing in test_ops.py only uses the samples from
+    # sample_inputs_func which do not test as exhaustively as these tests.
+
+    def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
+        """Compares op against op.ref for the given input and reduction kwargs"""
+        for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
+            kwargs.update(reduction_kwargs)
+            result = op(t, *args, **kwargs)
+            expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
+            self.assertEqual(result, expected, exact_dtype=False)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+    def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for scalar input tensors"""
+        self._test_ref(op, make_tensor([], device, dtype))
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+    def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for small input tensors"""
+        t = make_tensor((5, 3, 4, 2), device, dtype, exclude_zero=True)
+        self._test_ref(op, t)
+        for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
+            self._test_ref(op, t, dim=dim)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float32])
+    def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for a large 1D input tensor to check stability"""
+        self._test_ref(op, make_tensor((2 ** 20,), device, dtype, low=-1, high=2, exclude_zero=True))
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float32])
+    def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for a large 2D input tensor to test parallelism"""
+        t = make_tensor((32, 2 ** 16), device, dtype, low=-1, high=2, exclude_zero=True)
+        self._test_ref(op, t, dim=1)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float32])
+    def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for a very large input tensor that requires 64 bit indexing"""
+        self._test_ref(op, make_tensor((275000000,), device, dtype, low=-1, high=2, exclude_zero=True))
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+    def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for input tensors with duplicate values"""
+        t = make_tensor((8, 8), device, dtype, exclude_zero=True)
+        t[::2, ::2] = t[1::2, 1::2]
+        self._test_ref(op, t)
+        self._test_ref(op, t, dim=0)
+        self._test_ref(op, t, dim=1)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float32, torch.complex64])
+    def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for input tensors with extremal values"""
+        t = make_tensor((10,), device, dtype, exclude_zero=True)
+        extremals = [0, 1] + ([nan, inf, -inf] if torch.is_floating_point(t) else [])
+        for extremal in extremals:
+            t[5] = extremal
+            self._test_ref(op, t)
+
     ###########################################################################
     # TODO: Legacy tests - port to ReductionOpInfo
     ###########################################################################
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4331c92..2230808 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -187,6 +187,8 @@ class SampleInput(object):
                 return tuple(map(to_numpy, x))
             elif isinstance(x, dict):
                 return {k: to_numpy(v) for k, v in x.items()}
+            elif isinstance(x, torch.dtype):
+                return torch.empty(0, dtype=x).numpy().dtype
             elif isinstance(x, (numbers.Number, bool, str)):
                 return x
 
@@ -782,8 +784,8 @@ def _generate_reduction_inputs(device, dtype, requires_grad):
     """Generates input tensors for testing reduction operators"""
     yield make_tensor([], device, dtype, requires_grad=requires_grad)
     yield make_tensor([2], device, dtype, requires_grad=requires_grad)
-    yield make_tensor([2, 3], device, dtype, requires_grad=requires_grad, noncontiguous=True)
-    yield make_tensor([3, 2, 1, 5], device, dtype, requires_grad=requires_grad)
+    yield make_tensor([3, 5], device, dtype, requires_grad=requires_grad, noncontiguous=True)
+    yield make_tensor([3, 2, 1, 2], device, dtype, requires_grad=requires_grad)
 
 
 def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
@@ -927,6 +929,8 @@ class ReductionOpInfo(OpInfo):
         # Override OpInfo defaults and call base class __init__
         kwargs.setdefault('inplace_variant', None)
         kwargs.setdefault('sample_inputs_func', sample_inputs_func)
+        kwargs.setdefault('default_test_dtypes', (
+            torch.uint8, torch.int64, torch.float16, torch.bfloat16, torch.float32, torch.complex64))
         super(ReductionOpInfo, self).__init__(name, **kwargs)
         self.identity = identity
 
@@ -4080,38 +4084,6 @@ def sample_inputs_copysign(op_info, device, dtype, requires_grad, **kwargs):
 
     return list(generator())
 
-def sample_inputs_prod(op_info, device, dtype, requires_grad):
-    def make_arg(shape):
-        # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
-        return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)
-
-    def prod_single_zero():
-        result = make_arg(2 * (S,))
-        with torch.no_grad():
-            result[0, 1] = 0
-        return result
-
-    # will not be needed once OpInfo tests support Iterables
-    def sample_generator():
-        for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
-            yield SampleInput(sample.input)  # only Tensor, ignore other inputs
-            yield sample
-            sample.kwargs['keepdim'] = True
-            yield sample
-        yield SampleInput(prod_single_zero())
-        yield SampleInput(make_arg((3, 3, 3)), args=(1,))
-        yield SampleInput(make_arg((3, 3, 3)), args=(1,), kwargs={'keepdim': True})
-
-        # test zero scalar tensor
-        zero = make_arg(())
-        with torch.no_grad():
-            zero.zero_()
-        yield SampleInput(zero)
-        yield SampleInput(zero, args=(0,))
-        yield SampleInput(zero, args=(0,), kwargs={'keepdim': True})
-
-    return list(sample_generator())
-
 def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
 
@@ -5521,6 +5493,53 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     return op(input.triu() if upper else input.tril(), upper)
 
 
+def reference_reduction_numpy(f, supports_keepdims=True):
+    """Wraps a NumPy reduction operator.
+
+    The wrapper function will forward dim and keepdim kwargs to the wrapped
+    function as the NumPy equivalent axis and keepdims kwargs.
+
+    Args:
+        f: NumPy reduction operator to wrap
+        supports_keepdims (bool, optional): Whether the NumPy operator accepts
+            the keepdims parameter. If it does not, the wrapper will manually unsqueeze
+            the reduced dimensions when called with keepdim=True. Defaults to True.
+
+    Returns:
+        Wrapped function
+    """
+    @wraps(f)
+    def wrapper(x: np.ndarray, *args, **kwargs):
+        # Copy keys into a set
+        keys = set(kwargs.keys())
+
+        dim = kwargs.pop('dim', None)
+        keepdim = kwargs.pop('keepdim', False)
+
+        if 'dim' in keys:
+            if x.ndim == 0:
+                # NumPy reductions don't accept dim=0 for scalar inputs
+                for i in dim if isinstance(dim, tuple) else (dim,):
+                    assert i in {0, -1}
+                kwargs['axis'] = None
+            else:
+                kwargs['axis'] = tuple(dim) if isinstance(dim, Sequence) else dim
+
+        if 'keepdim' in keys and supports_keepdims:
+            kwargs['keepdims'] = keepdim
+
+        result = f(x, *args, **kwargs)
+
+        # Unsqueeze reduced dimensions if NumPy does not support keepdims
+        if keepdim and not supports_keepdims and x.ndim > 0:
+            dim = list(range(x.ndim)) if dim is None else dim
+            result = np.expand_dims(result, dim)
+
+        return result
+
+    return wrapper
+
+
 # Operator database (sorted alphabetically)
 op_db: List[OpInfo] = [
     UnaryUfuncInfo('abs',
@@ -7039,15 +7058,6 @@ op_db: List[OpInfo] = [
            supports_out=False,
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_max_min_reduction_no_dim,),
-    # TODO(@heitorschueroff) Add test for dtype kwarg
-    OpInfo('mean',
-           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
-           assert_autodiffed=True,
-           supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_reduction,
-           # Need to skip out test because one of the overload for mean does not support it
-           # TODO(@heitorschueroff) fix this when implementing ReductionInfo
-           skips=(SkipInfo('TestCommon', 'test_out'),)),
     OpInfo('quantile',
            dtypes=floating_types(),
           sample_inputs_func=sample_inputs_reduction_quantile),
@@ -8890,6 +8900,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.bool,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.all),
         skips=(
             # FIXME: does not support passing keepdim without dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -8897,7 +8908,8 @@ op_db: List[OpInfo] = [
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
             # FIXME: uint8 input returns uint8 instead of bool
-            SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+            SkipInfo('TestReductions', 'test_result_dtype',
+                     dtypes=[torch.uint8]),
         ),
     ),
     ReductionOpInfo(
@@ -8908,6 +8920,7 @@ op_db: List[OpInfo] = [
         'any',
         nan_policy='propagate',
         supports_autograd=False,
         result_dtype=torch.bool,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.any),
         skips=(
             # FIXME: does not support passing keepdim without dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -8915,14 +8928,15 @@ op_db: List[OpInfo] = [
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
             # FIXME: uint8 input returns uint8 instead of bool
-            SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+            SkipInfo('TestReductions', 'test_result_dtype',
+                     dtypes=[torch.uint8]),
         ),
     ),
     ReductionOpInfo(
         'amax',
         nan_policy='propagate',
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs),
+        ref=reference_reduction_numpy(np.amax),
         skips=(
             # FIXME: sum reduces all dimensions when dim=[]
             SkipInfo('TestReductions', 'test_dim_empty'),
@@ -8933,7 +8947,7 @@ op_db: List[OpInfo] = [
         'amin',
         nan_policy='propagate',
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs),
+        ref=reference_reduction_numpy(np.amin),
         skips=(
             # FIXME: sum reduces all dimensions when dim=[]
             SkipInfo('TestReductions', 'test_dim_empty'),
@@ -8946,6 +8960,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.int64,
         dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
         skips=(
             # FIXME: keepdim parameter is ignored when dim=None
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -8958,6 +8973,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.int64,
         dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
         skips=(
             # FIXME: keepdim parameter is ignored when dim=None
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -8972,6 +8988,7 @@ op_db: List[OpInfo] = [
         result_dtype=torch.int64,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_reduction_count_nonzero,
+        ref=reference_reduction_numpy(np.count_nonzero),
         skips=(
             # FIXME: count_nonzero does not accept keepdim kwarg
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -8986,6 +9003,35 @@ op_db: List[OpInfo] = [
         ),
     ),
     ReductionOpInfo(
+        'mean',
+        nan_policy='propagate',
+        supports_out=False,
+        supports_forward_ad=True,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.mean),
+        decorators=(
+            # FIXME: fix precision
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-05, rtol=1e-02),
+            }), 'TestReductions', 'test_noncontiguous_all'),
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-05, rtol=1e-02),
+            }), 'TestReductions', 'test_ref_small_input'),
+        ),
+        skips=(
+            # FIXME: mean does not support passing keepdim without passing dim
+            SkipInfo('TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: mean reduces all dimensions when dim=[]
+            SkipInfo('TestReductions', 'test_dim_empty'),
+            SkipInfo('TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: mean does not support passing None to dim
+            SkipInfo('TestReductions', 'test_dim_none'),
+            SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+        ),
+    ),
+    ReductionOpInfo(
         'prod',
         identity=1,
         nan_policy='propagate',
@@ -8995,7 +9041,7 @@ op_db: List[OpInfo] = [
         gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
         dtypes=all_types_and_complex_and(torch.bool),
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
-        sample_inputs_func=sample_inputs_prod,
+        ref=reference_reduction_numpy(np.prod),
         skips=(
             # FIXME: prod does not support passing keepdim without passing dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9005,6 +9051,11 @@ op_db: List[OpInfo] = [
             # FIXME: prod reduces all dimensions when dim=[]
             SkipInfo('TestReductions', 'test_dim_empty'),
             SkipInfo('TestReductions', 'test_dim_empty_keepdim'),
             # FIXME: prod does not support passing None to dim
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: improve precision, failing with nan != inf
+            SkipInfo('TestReductions', 'test_ref_small_input',
+                     dtypes=[torch.float16, torch.complex64]),
+            SkipInfo('TestReductions', 'test_ref_duplicate_values',
+                     dtypes=[torch.uint8, torch.float16, torch.complex64]),
         ),
     ),
     ReductionOpInfo(
@@ -9015,6 +9066,22 @@ op_db: List[OpInfo] = [
         supports_forward_ad=True,
         promotes_int_to_int64=True,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.sum),
+        decorators=(
+            # FIXME: fix precision
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-05, rtol=1e-02),
+            }), 'TestReductions', 'test_noncontiguous_all'),
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-03, rtol=1e-02),
+            }), 'TestReductions', 'test_ref_small_input'),
+            DecorateInfo(toleranceOverride({
+                torch.float32: tol(atol=1e-03, rtol=1e-03),
+            }), 'TestReductions', 'test_ref_large_input_64bit_indexing'),
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-05, rtol=1e-02),
+            }), 'TestReductions', 'test_ref_duplicate_values'),
+        ),
         skips=(
             # FIXME: sum does not support passing keepdim without passing dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9033,6 +9100,22 @@ op_db: List[OpInfo] = [
         supports_out=False,
         promotes_int_to_int64=True,
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.nansum),
+        decorators=(
+            # FIXME: fix precision
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-05, rtol=1e-02),
+            }), 'TestReductions', 'test_noncontiguous_all'),
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-03, rtol=1e-02),
+            }), 'TestReductions', 'test_ref_small_input'),
+            DecorateInfo(toleranceOverride({
+                torch.float32: tol(atol=1e-03, rtol=1e-03),
+            }), 'TestReductions', 'test_ref_large_input_64bit_indexing'),
+            DecorateInfo(toleranceOverride({
+                torch.float16: tol(atol=1e-05, rtol=1e-02),
+            }), 'TestReductions', 'test_ref_duplicate_values'),
+        ),
         skips=(
             # FIXME: nansum does not support passing keepdim without passing dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
-- 
2.7.4
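
A minimal standalone sketch, not part of the patch above: it only illustrates the dim/keepdim to axis/keepdims
correspondence that reference_reduction_numpy forwards when the new TestReductions reference tests compare a torch
reduction against its NumPy counterpart. The tensor shape and the use of sum here are arbitrary examples.

    # Illustration only (assumes NumPy and PyTorch are installed); the NumPy call is what a
    # reference_reduction_numpy(np.sum)-style ref effectively receives for dim=(0, 2), keepdim=True.
    import numpy as np
    import torch

    t = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    x = t.numpy()

    torch_result = torch.sum(t, dim=(0, 2), keepdim=True)   # torch-style kwargs
    numpy_result = np.sum(x, axis=(0, 2), keepdims=True)    # NumPy-style kwargs

    assert torch_result.shape == numpy_result.shape == (1, 3, 1)
    assert np.allclose(torch_result.numpy(), numpy_result)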