[Reland] Added reference tests to ReductionOpInfo (#64273)
author Heitor Schueroff <heitorschueroff@fb.com>
Mon, 13 Sep 2021 03:04:19 +0000 (20:04 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 03:05:43 +0000 (20:05 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64273

Reintroduced sample_inputs_prod and constrained the range of values used by the large-input reference tests.
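
For context, a minimal standalone sketch (illustrative only, not part of this patch) of the comparison the reference tests perform, with input values constrained to [-1, 1] so that large reductions stay numerically stable:

    import numpy as np
    import torch

    # Large 1D input with values constrained to [-1, 1] for numerical stability.
    t = torch.empty(2 ** 20, dtype=torch.float64).uniform_(-1, 1)
    result = torch.sum(t, dim=0)
    expected = np.sum(t.numpy(), axis=0)
    assert torch.allclose(result, torch.tensor(expected))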

This reverts commit e4fd2ab59ce8645f5ae9477c7724b6af82124b3b.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30672097

Pulled By: heitorschueroff

fbshipit-source-id: b44ed8dfd5eb0c74c194164dafc3242f6728a78f

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
test/test_reductions.py
torch/testing/_internal/common_methods_invocations.py
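
The central addition in common_methods_invocations.py is reference_reduction_numpy, which adapts a NumPy reduction so it can be used as an OpInfo ref. A simplified standalone sketch of the kwarg translation it performs (the helper name here is hypothetical; the patched version below also handles scalar inputs and operators that lack keepdims support):

    import numpy as np

    def numpy_ref(f):
        # Translate torch-style dim/keepdim kwargs into NumPy's axis/keepdims.
        def wrapper(x, *args, dim=None, keepdim=False, **kwargs):
            return f(x, *args, axis=dim, keepdims=keepdim, **kwargs)
        return wrapper

    ref = numpy_ref(np.sum)
    arr = np.arange(6, dtype=np.float64).reshape(2, 3)
    out = ref(arr, dim=1, keepdim=True)  # same as np.sum(arr, axis=1, keepdims=True)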

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 89d2fb2..01ed54e 100644
@@ -163,24 +163,29 @@ static void std_var_kernel_impl(TensorIterator& iter, int64_t correction, bool t
 }
 
 static void prod_kernel_impl(TensorIterator& iter) {
-  // Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context]
+  // Workaround for the error: '*' in boolean context, suggest '&&' instead
+  // [-Werror=int-in-bool-context]
   if (iter.dtype() == ScalarType::Bool) {
     using scalar_t = bool;
     binary_kernel_reduce_vec(
-      iter,
-      [=](scalar_t a, scalar_t b) -> scalar_t { return a && b; },
-      [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return a && b; },
-      // NOLINTNEXTLINE(bugprone-argument-comment)
-      /*identity=*/1);
-  } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
-      binary_kernel_reduce_vec(
         iter,
-        [=](scalar_t a, scalar_t b) -> scalar_t { return a * b; },
-        [=](Vectorized <scalar_t> a, Vectorized <scalar_t> b) { return a * b; },
+        [=](scalar_t a, scalar_t b)
+            __ubsan_ignore_undefined__ -> scalar_t { return a && b; },
+        [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
+            __ubsan_ignore_undefined__ { return a && b; },
         // NOLINTNEXTLINE(bugprone-argument-comment)
         /*identity=*/1);
-      });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] {
+      binary_kernel_reduce_vec(
+          iter,
+          [=](scalar_t a, scalar_t b)
+              __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
+          [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b)
+              __ubsan_ignore_undefined__ { return a * b; },
+          // NOLINTNEXTLINE(bugprone-argument-comment)
+          /*identity=*/1);
+    });
   }
 }
 
test/test_reductions.py
index 9760eae..24dada1 100644
@@ -299,6 +299,115 @@ class TestReductions(TestCase):
             result = op(t, *args, dim=dim, **kwargs)
             self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
 
+    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
+        """Helper method to test noncontiguous input tensors."""
+        assert not t.is_contiguous()
+
+        t_contig = t.contiguous()
+        for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs):
+            kwargs.update(reduction_kwargs)
+            result = op(t, *args, **kwargs)
+            expected = op(t_contig, *args, **kwargs)
+            self.assertEqual(result, expected)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing along noncontiguous innermost dimension."""
+        t = make_tensor((10, 10), device, dtype, low=-1, high=1)
+        self._test_noncontiguous(op, t[:, ::2], dim=1)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing along noncontiguous outermost dimension."""
+        t = make_tensor((10, 10), device, dtype, low=-1, high=1)
+        self._test_noncontiguous(op, t[::2, :], dim=0)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing all dimensions of a noncontiguous tensor."""
+        t = make_tensor((5, 5, 5), device, dtype, low=-1, high=1)
+        self._test_noncontiguous(op, t[::2, ::3, 1:-1:2])
+
+    @ops(reduction_ops)
+    def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing a transposed tensor."""
+        t = make_tensor((5, 5), device, dtype, low=-1, high=1)
+        self._test_noncontiguous(op, t.T)
+
+    @ops(reduction_ops)
+    def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
+        """Tests reducing a tensor with expanded singleton dimensions."""
+        t = make_tensor((2, 3), device, dtype, low=-1, high=1)
+        self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1))
+
+    # NumPy does not support BFloat16 so we don't test that against reference
+    # implementations. We also don't compare dtypes or test for different
+    # keepdim because we already have other tests covering those.
+    # The test_reference_testing in test_ops.py only uses the samples from
+    # sample_inputs_func which do not test as exhaustively as these tests.
+
+    def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
+        """Compares op against op.ref for the given input and reduction kwargs"""
+        for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
+            kwargs.update(reduction_kwargs)
+            result = op(t, *args, **kwargs)
+            expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
+            self.assertEqual(result, expected, exact_dtype=False)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+    def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for scalar input tensors"""
+        self._test_ref(op, make_tensor([], device, dtype))
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+    def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for small input tensors"""
+        t = make_tensor((5, 3, 4, 2), device, dtype, low=-2, high=2, exclude_zero=True)
+        self._test_ref(op, t)
+        for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
+            self._test_ref(op, t, dim=dim)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float64])
+    def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for a large 1D input tensor to check stability"""
+        self._test_ref(op, make_tensor((2 ** 20,), device, dtype, low=-1, high=1, exclude_zero=True))
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float64])
+    def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for a large 2D input tensor to test parallelism"""
+        t = make_tensor((32, 2 ** 16), device, dtype, low=-1, high=1, exclude_zero=True)
+        self._test_ref(op, t, dim=1)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float64])
+    def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for a very large input tensor that requires 64 bit indexing"""
+        self._test_ref(op, make_tensor((275000000,), device, dtype, low=-1, high=1, exclude_zero=True))
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=get_all_dtypes(include_bfloat16=False))
+    def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for input tensors with duplicate values"""
+        t = make_tensor((4, 4), device, dtype, low=-2, high=2, exclude_zero=True)
+        t[::2, ::2] = t[1::2, 1::2]
+        self._test_ref(op, t)
+        self._test_ref(op, t, dim=0)
+        self._test_ref(op, t, dim=1)
+
+    @ops(filter(lambda op: op.ref is not None, reduction_ops),
+         allowed_dtypes=[torch.float32, torch.complex64])
+    def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
+        """Compares op against reference for input tensors with extremal values"""
+        t = make_tensor((5,), device, dtype, exclude_zero=True)
+        extremals = [0, 1, nan, inf, -inf]
+        for extremal in extremals:
+            t[2] = extremal
+            self._test_ref(op, t)
+
     ###########################################################################
     # TODO: Legacy tests - port to ReductionOpInfo
     ###########################################################################
torch/testing/_internal/common_methods_invocations.py
index 8451791..84ff419 100644
@@ -188,6 +188,8 @@ class SampleInput(object):
                 return tuple(map(to_numpy, x))
             elif isinstance(x, dict):
                 return {k: to_numpy(v) for k, v in x.items()}
+            elif isinstance(x, torch.dtype):
+                return torch_to_numpy_dtype_dict[x]
             elif isinstance(x, (numbers.Number, bool, str)):
                 return x
 
@@ -785,8 +787,8 @@ def _generate_reduction_inputs(device, dtype, requires_grad):
     """Generates input tensors for testing reduction operators"""
     yield make_tensor([], device, dtype, requires_grad=requires_grad)
     yield make_tensor([2], device, dtype, requires_grad=requires_grad)
-    yield make_tensor([2, 3], device, dtype, requires_grad=requires_grad, noncontiguous=True)
-    yield make_tensor([3, 2, 1, 5], device, dtype, requires_grad=requires_grad)
+    yield make_tensor([3, 5], device, dtype, requires_grad=requires_grad, noncontiguous=True)
+    yield make_tensor([3, 2, 1, 2], device, dtype, requires_grad=requires_grad)
 
 
 def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
@@ -930,6 +932,8 @@ class ReductionOpInfo(OpInfo):
         # Override OpInfo defaults and call base class __init__
         kwargs.setdefault('inplace_variant', None)
         kwargs.setdefault('sample_inputs_func', sample_inputs_func)
+        kwargs.setdefault('default_test_dtypes', (
+            torch.uint8, torch.int64, torch.float16, torch.bfloat16, torch.float32, torch.complex64))
         super(ReductionOpInfo, self).__init__(name, **kwargs)
 
         self.identity = identity
@@ -4305,7 +4309,6 @@ def sample_inputs_prod(op_info, device, dtype, requires_grad):
 
     return list(sample_generator())
 
-
 def sample_inputs_nextafter(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
 
@@ -5761,6 +5764,54 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     return op(input.triu() if upper else input.tril(), upper)
 
 
+def reference_reduction_numpy(f, supports_keepdims=True):
+    """Wraps a NumPy reduction operator.
+
+    The wrapper function will forward dim and keepdim kwargs to the wrapped
+    function as the NumPy equivalent axis and keepdims kwargs.
+
+    Args:
+        f: NumPy reduction operator to wrap
+        supports_keepdims (bool, optional): Whether the NumPy operator accepts
+            keepdims parameter. If it does not, the wrapper will manually unsqueeze
+            the reduced dimensions if it was called with keepdim=True. Defaults to True.
+
+    Returns:
+        Wrapped function
+    """
+    @wraps(f)
+    def wrapper(x: np.ndarray, *args, **kwargs):
+        # Copy keys into a set
+        keys = set(kwargs.keys())
+
+        dim = kwargs.pop('dim', None)
+        keepdim = kwargs.pop('keepdim', False)
+
+        if 'dim' in keys:
+            dim = tuple(dim) if isinstance(dim, Sequence) else dim
+
+            # NumPy reductions don't accept dim=0 for scalar inputs
+            # so we convert it to None if and only if dim is equivalent
+            if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}:
+                kwargs['axis'] = None
+            else:
+                kwargs['axis'] = dim
+
+        if 'keepdim' in keys and supports_keepdims:
+            kwargs['keepdims'] = keepdim
+
+        result = f(x, *args, **kwargs)
+
+        # Unsqueeze reduced dimensions if NumPy does not support keepdims
+        if keepdim and not supports_keepdims and x.ndim > 0:
+            dim = list(range(x.ndim)) if dim is None else dim
+            result = np.expand_dims(result, dim)
+
+        return result
+
+    return wrapper
+
+
 # Operator database (sorted alphabetically)
 op_db: List[OpInfo] = [
     UnaryUfuncInfo('abs',
@@ -7290,15 +7341,6 @@ op_db: List[OpInfo] = [
            supports_out=False,
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_max_min_reduction_no_dim,),
-    # TODO(@heitorschueroff) Add test for dtype kwarg
-    OpInfo('mean',
-           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
-           assert_autodiffed=True,
-           supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_reduction,
-           # Need to skip out test because one of the overload for mean does not support it
-           # TODO(@heitorschueroff) fix this when implementing ReductionInfo
-           skips=(SkipInfo('TestCommon', 'test_out'),)),
     OpInfo('quantile',
            dtypes=floating_types(),
            sample_inputs_func=sample_inputs_reduction_quantile),
@@ -9318,6 +9360,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.bool,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.all),
         skips=(
             # FIXME: does not support passing keepdim without dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9325,7 +9368,8 @@ op_db: List[OpInfo] = [
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
             # FIXME: uint8 input returns uint8 instead of bool
-            SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+            SkipInfo('TestReductions', 'test_result_dtype',
+                     dtypes=[torch.uint8]),
         ),
     ),
     ReductionOpInfo(
@@ -9336,6 +9380,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.bool,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.any),
         skips=(
             # FIXME: does not support passing keepdim without dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9343,14 +9388,15 @@ op_db: List[OpInfo] = [
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
             # FIXME: uint8 input returns uint8 instead of bool
-            SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]),
+            SkipInfo('TestReductions', 'test_result_dtype',
+                     dtypes=[torch.uint8]),
         ),
     ),
     ReductionOpInfo(
         'amax',
         nan_policy='propagate',
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs),
+        ref=reference_reduction_numpy(np.amax),
         skips=(
             # FIXME: sum reduces all dimensions when dim=[]
             SkipInfo('TestReductions', 'test_dim_empty'),
@@ -9361,7 +9407,7 @@ op_db: List[OpInfo] = [
         'amin',
         nan_policy='propagate',
         dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs),
+        ref=reference_reduction_numpy(np.amin),
         skips=(
             # FIXME: sum reduces all dimensions when dim=[]
             SkipInfo('TestReductions', 'test_dim_empty'),
@@ -9374,6 +9420,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.int64,
         dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmax, supports_keepdims=False),
         skips=(
             # FIXME: keepdim parameter is ignored when dim=None
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9386,6 +9433,7 @@ op_db: List[OpInfo] = [
         supports_autograd=False,
         result_dtype=torch.int64,
         dtypes=all_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.argmin, supports_keepdims=False),
         skips=(
             # FIXME: keepdim parameter is ignored when dim=None
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9400,6 +9448,7 @@ op_db: List[OpInfo] = [
         result_dtype=torch.int64,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_reduction_count_nonzero,
+        ref=reference_reduction_numpy(np.count_nonzero),
         skips=(
             # FIXME: count_nonzero does not accept keepdim kwarg
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9414,6 +9463,33 @@ op_db: List[OpInfo] = [
         ),
     ),
     ReductionOpInfo(
+        'mean',
+        nan_policy='propagate',
+        supports_out=False,
+        supports_forward_ad=True,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.mean),
+        skips=(
+            # FIXME: mean does not support passing keepdim without passing dim
+            SkipInfo('TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: mean reduces all dimensions when dim=[]
+            SkipInfo('TestReductions', 'test_dim_empty'),
+            SkipInfo('TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: mean does not support passing None to dim
+            SkipInfo('TestReductions', 'test_dim_none'),
+            SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: improve precision
+            SkipInfo('TestReductions', 'test_noncontiguous_all',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_small_input',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_extremal_values',
+                     device_type='cuda', dtypes=[torch.complex64]),
+        ),
+    ),
+    ReductionOpInfo(
         'prod',
         identity=1,
         nan_policy='propagate',
@@ -9424,6 +9500,7 @@ op_db: List[OpInfo] = [
         dtypes=all_types_and_complex_and(torch.bool),
         dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_prod,
+        ref=reference_reduction_numpy(np.prod),
         skips=(
             # FIXME: prod does not support passing keepdim without passing dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9433,6 +9510,11 @@ op_db: List[OpInfo] = [
             # FIXME: prod does not support passing None to dim
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: improve precision, failing with nan != inf
+            SkipInfo('TestReductions', 'test_ref_small_input',
+                     dtypes=[torch.float16, torch.complex64]),
+            SkipInfo('TestReductions', 'test_ref_duplicate_values',
+                     dtypes=[torch.uint8, torch.float16, torch.complex64]),
         ),
     ),
     ReductionOpInfo(
@@ -9443,6 +9525,7 @@ op_db: List[OpInfo] = [
         supports_forward_ad=True,
         promotes_int_to_int64=True,
         dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.sum),
         skips=(
             # FIXME: sum does not support passing keepdim without passing dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9452,6 +9535,13 @@ op_db: List[OpInfo] = [
             # FIXME: sum does not support passing None to dim
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: improve precision
+            SkipInfo('TestReductions', 'test_noncontiguous_all',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_small_input',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_duplicate_values',
+                     dtypes=[torch.float16]),
         ),
     ),
     ReductionOpInfo(
@@ -9461,6 +9551,7 @@ op_db: List[OpInfo] = [
         supports_out=False,
         promotes_int_to_int64=True,
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+        ref=reference_reduction_numpy(np.nansum),
         skips=(
             # FIXME: nansum does not support passing keepdim without passing dim
             SkipInfo('TestReductions', 'test_dim_default_keepdim'),
@@ -9470,6 +9561,13 @@ op_db: List[OpInfo] = [
             # FIXME: nansum does not support passing None to dim
             SkipInfo('TestReductions', 'test_dim_none'),
             SkipInfo('TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: improve precision
+            SkipInfo('TestReductions', 'test_noncontiguous_all',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_small_input',
+                     dtypes=[torch.float16]),
+            SkipInfo('TestReductions', 'test_ref_duplicate_values',
+                     dtypes=[torch.float16]),
         ),
     ),
     OpInfo(