From 950f7c023706db6db64a37b7b3c8b760679f7d3f Mon Sep 17 00:00:00 2001 From: Heitor Schueroff Date: Thu, 26 Aug 2021 07:17:24 -0700 Subject: [PATCH] Added API tests to ReductionOpInfo and ported amax/amin/nansum tests (#62899) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62899 Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D30408816 Pulled By: heitorschueroff fbshipit-source-id: 6cb0aa7fa7edba93549ef873baa2fb8a003bd91d --- test/test_reductions.py | 242 ++++++++++++++++++++- .../_internal/common_methods_invocations.py | 241 +++++++++++++------- 2 files changed, 397 insertions(+), 86 deletions(-) diff --git a/test/test_reductions.py b/test/test_reductions.py index e224eae..e716336 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -2,18 +2,20 @@ import torch import numpy as np import math -from typing import Dict, List +from typing import Dict, List, Sequence import random from functools import partial from itertools import product, combinations, permutations import warnings from torch._six import inf, nan +from torch.testing import ( + integral_types_and, floating_and_complex_types_and) from torch.testing._internal.common_utils import ( TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict, IS_WINDOWS, make_tensor) from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, + OpDTypes, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, ops, precisionOverride) from torch.testing._internal.common_methods_invocations import ( ReductionOpInfo, reduction_ops) @@ -55,18 +57,244 @@ def _rand_shape(dim, min_size, max_size): shape.append(random.randint(min_size, max_size)) return tuple(shape) +def _reduced_shape(shape, dim=None, keepdim=False): + """Computes the expected reduced shape given dim and keepdim + + Args: + shape: The shape to reduce + dim : The dimensions to reduce + keepdim: If true, reduced dimensions have size 1 in the reduced shape, + otherwise they are removed from the reduced shape. + + Returns: + The reduced shape + """ + if dim is None: + return [1] * len(shape) if keepdim else [] + + # Wrap negative dims + dim = dim if isinstance(dim, Sequence) else [dim] + dim = set(i if i >= 0 else len(shape) + i for i in dim) + + result = [] + for i, size in enumerate(shape): + if i not in dim: + result.append(size) + elif keepdim: + result.append(1) + + return result + class TestReductions(TestCase): ########################################################################### # ReductionOpInfo unit tests ########################################################################### - @ops(reduction_ops, allowed_dtypes=[torch.float]) - def test_dim_default(self, device, dtype, op: ReductionOpInfo): - """Tests that the default behavior is to reduce all dimensions.""" - t = make_tensor((2, 3), device, dtype) + def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim): + """Tests output shape for input with ndim and dim and keepdim kwargs""" + shape = torch.randint(2, 5, (ndim,)).tolist() + t = make_tensor(shape, device, torch.float) + args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim)) + result = op(t, *args, **dim_keepdim, **kwargs) + expected_shape = _reduced_shape(shape, **dim_keepdim) + self.assertEqual(result.shape, expected_shape, f""" + expected output shape to be {expected_shape} but got {list(result.shape)} + for input shape {shape} and {dim_keepdim} + """) + + # TODO(@heitorschueroff) combine cases with and without keepdim once + # there's support for a @parametrize decorator. + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_default(self, device, op: ReductionOpInfo): + """Tests that the default dim reduces all dimensions.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_default_keepdim(self, device, op: ReductionOpInfo): + """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_none(self, device, op: ReductionOpInfo): + """Tests that dim=None reduces all dimensions.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim, dim=None) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_none_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1.""" + for ndim in range(3): + self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_single(self, device, op: ReductionOpInfo): + """Tests that dim=i reduces dimension i.""" + self._test_dim_keepdim(op, device, ndim=0, dim=0) + self._test_dim_keepdim(op, device, ndim=1, dim=0) + self._test_dim_keepdim(op, device, ndim=2, dim=-1) + self._test_dim_keepdim(op, device, ndim=3, dim=1) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_single_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=i, when keepdim=True, reduces dimension i to size 1.""" + self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True) + self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True) + self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True) + self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_empty(self, device, op: ReductionOpInfo): + """Tests that dim=[] is a no-op""" + self._test_dim_keepdim(op, device, ndim=0, dim=[]) + self._test_dim_keepdim(op, device, ndim=2, dim=[]) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_empty_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=[], when keepdim=True, is a no-op""" + self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True) + self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi(self, device, op: ReductionOpInfo): + """Tests that dim=[i, j, ...] reduces dimensions i, j, ....""" + self._test_dim_keepdim(op, device, ndim=1, dim=[0]) + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_keepdim(self, device, op: ReductionOpInfo): + """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, .... to size 1.""" + self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True) + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_unsorted(self, device, op: ReductionOpInfo): + """Tests that operator correctly handles unsorted dim list.""" + self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2]) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo): + """Tests that operator correctly handles unsorted dim list when keepdim=True.""" + self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True) + + @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_duplicate(self, device, op: ReductionOpInfo): + """Tests that an error is raised if dim has duplicate entries.""" + with self.assertRaises(RuntimeError): + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2]) + + @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) + def test_dim_multi_unsupported(self, device, op: ReductionOpInfo): + """Tests that ops claiming to not support multi dim actually don't.""" + with self.assertRaises(TypeError): + self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_offbounds(self, device, op: ReductionOpInfo): + """Tests that passing an off-bounds dim throws""" + with self.assertRaises(IndexError): + self._test_dim_keepdim(op, device, ndim=2, dim=2) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_dim_ndim_limit(self, device, op: ReductionOpInfo): + """Tests that an exception is raised when reducing a tensor with more + than 64 dims along some specific dimensions. dim=None is ok""" + t = make_tensor([1] * 65, device, torch.float) + with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): + op(t, dim=0) + + @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported) + def test_identity(self, device, dtype, op: ReductionOpInfo): + """Tests that the identity value is an identity for the operator""" + t = make_tensor((10,), device, dtype) + t[1::2] = op.identity + args, kwargs = next(op.generate_args_kwargs(t)) + result = op(t[::2], *args, **kwargs) + result_with_identity = op(t, *args, **kwargs) + self.assertEqual(result, result_with_identity, """ + Adding identity value to the input tensor should not change the result. + """) + + # TODO(@heitorschueroff) Update these to use the nan_policy kwarg once + # it is added to reduction operators. + + @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported, + allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16)) + def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo): + """Tests that nan is propagated to the output by default""" + t = make_tensor((5,), device, dtype) + t[2] = torch.nan args, kwargs = next(op.generate_args_kwargs(t)) - self.assertEqual(op(t, *args, **kwargs).ndim, 0) + result = op(t, *args, **kwargs) + self.assertTrue(result.isnan()) + + @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported, + allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16)) + def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo): + """Tests that NaN values do not affect the result.""" + t = make_tensor((10,), device, dtype) + t[1::2] = torch.nan + args, kwargs = next(op.generate_args_kwargs(t)) + result = op(t[::2], *args, **kwargs) + result_with_nan = op(t, *args, **kwargs) + self.assertEqual(result, result_with_nan) + + @ops(reduction_ops, dtypes=OpDTypes.supported) + def test_result_dtype(self, device, dtype, op: ReductionOpInfo): + """Tests that the result has the correct dtype""" + t = make_tensor((5,), device, dtype) + args, kwargs = next(op.generate_args_kwargs(t)) + result: torch.Tensor = op(t, *args, **kwargs) + is_integral = dtype in integral_types_and(torch.bool) + if op.promotes_int_to_float and is_integral: + self.assertTrue(torch.is_floating_point(result.dtype)) + elif op.promotes_int_to_int64 and is_integral: + self.assertEqual(result.dtype, torch.int64) + elif op.result_dtype is not None: + self.assertEqual(result.dtype, op.result_dtype) + else: + self.assertEqual(result.dtype, dtype) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo): + """Tests for consistent behavior when reducing over an empty slice. + + The rules for reducing over an empty slice are as follows: + - Return the identity value if the operator has one + - Otherwise, return NaN if the operator promotes integral dtype to + floating point dtypes. + - Otherwise, raise an error + + See discussion here https://github.com/pytorch/pytorch/issues/61901 + """ + t = make_tensor((0, 2, 3), device, torch.float) + for dim in [0] + [[0, 2]] if op.supports_multiple_dims else []: + args, kwargs = next(op.generate_args_kwargs(t, dim=dim)) + if op.identity is not None: + # Reducing along empty slice should return identity + result = op(t, *args, dim=dim, **kwargs) + self.assertEqual(result, torch.full_like(result, op.identity)) + elif op.promotes_int_to_float: + # Reducing along empty slice should return NaN + result = op(t, *args, dim=dim, **kwargs) + self.assertEqual(result, torch.full_like(result, torch.nan)) + else: + # Reducing along empty slice should raise an error + with self.assertRaises(IndexError): + op(t, *args, dim=dim, **kwargs) + + @ops(reduction_ops, dtypes=OpDTypes.none) + def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo): + """Tests that reducing a nonempty slice of an empty tensor returns an + empty tensor with the dimensions reduced.""" + t = make_tensor((0, 2, 3), device, torch.float) + for dim in [1] + [[1, 2]] if op.supports_multiple_dims else []: + args, kwargs = next(op.generate_args_kwargs(t, dim=dim)) + result = op(t, *args, dim=dim, **kwargs) + self.assertEqual(result.shape, _reduced_shape(t.shape, dim)) ########################################################################### # TODO: Legacy tests - port to ReductionOpInfo diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3839b2e..4331c92 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2170,28 +2170,6 @@ def sample_inputs_take_along_dim(op_info, device, dtype, requires_grad, **kwargs ) -def sample_inputs_amax_amin(op_info, device, dtype, requires_grad, **kwargs): - # Ordered as (input shape, kwargs) - test_cases: Tuple[tuple, dict] = ( # type: ignore[assignment] - ((S, S, S), {}), - ((S, S, S), {'dim': 1}), - ((S, S, S), {'dim': (1, 2,)}), - ((S, S, S), {'dim': 1, 'keepdim': True}), - ((), {'dim': 0}), - ((), {}), - ((), {'dim': 0, 'keepdim': True}), - ) - - samples: List[SampleInput] = [] - for shape, kwargs in test_cases: - samples.append(SampleInput( - make_tensor(shape, device, dtype, requires_grad=requires_grad), - kwargs=kwargs)) - - return samples - -# TODO (@heitorschueroff) Once aminmax supports multiple dims this should -# be combined with the above test. def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): test_cases: Tuple[tuple, dict] = ( # type: ignore[assignment] ((S, S, S), {}), @@ -2210,33 +2188,6 @@ def sample_inputs_aminmax(op_info, device, dtype, requires_grad, **kwargs): return samples -def sample_inputs_argmax_argmin(op_info, device, dtype, requires_grad, **kwargs): - test_cases = ( - ((2, 2, 2), ()), - ((2, 2, 2), (0,)), - ((2, 2, 2), (1,)), - ((2, 2, 2), (2,)), - ((2, 2, 2), (2, True,)), - ((2, 2, 2), (None,)), - ((), (0,)), - ((), ()), - ((), (None, True,)), - ((1,), ()), - ((1,), (0,)), - ((1,), (0, True)), - ((2,), ()), - ((2,), (0,)), - ((2,), (0, True)), - ((2, 2, 3), ()), - ((2, 2, 3), (0,)), - ((2, 2, 3), (1,)), - ((2, 2, 3), (None, True)), - ) - return tuple(SampleInput((make_tensor(size, device, dtype, - requires_grad=requires_grad)), - args=args) - for size, args in test_cases) - def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): test_cases = ( ((1,), 0, None, None), @@ -2634,6 +2585,14 @@ def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad): return inputs +def sample_inputs_reduction_count_nonzero(*args, **kwargs): + """Sample inputs for count_nonzero""" + samples: List[SampleInput] = sample_inputs_reduction(*args, **kwargs) + # count_nonzero does not support keepdim yet + for sample in samples: + sample.kwargs.pop('keepdim', None) + return samples + def sample_inputs_leaky_relu(op_info, device, dtype, requires_grad): N = 10 tensors = [SampleInput(make_tensor((N, N), device=device, dtype=dtype, @@ -5823,22 +5782,6 @@ op_db: List[OpInfo] = [ # TODO: update sample inputs with for_inplace_variant kwarg to support this test SkipInfo('TestCommon', 'test_variant_consistency_eager'),), sample_inputs_func=sample_inputs_addcmul_addcdiv), - OpInfo('amax', - ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs), - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - sample_inputs_func=sample_inputs_amax_amin,), - OpInfo('amin', - ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs), - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - sample_inputs_func=sample_inputs_amax_amin), - OpInfo('argmax', - dtypes=all_types_and(torch.float16, torch.bfloat16), - supports_autograd=False, - sample_inputs_func=sample_inputs_argmax_argmin,), - OpInfo('argmin', - dtypes=all_types_and(torch.float16, torch.bfloat16), - supports_autograd=False, - sample_inputs_func=sample_inputs_argmax_argmin,), UnaryUfuncInfo('asin', aliases=('arcsin', ), ref=np.arcsin, @@ -7096,10 +7039,6 @@ op_db: List[OpInfo] = [ supports_out=False, supports_forward_ad=True, sample_inputs_func=sample_inputs_max_min_reduction_no_dim,), - OpInfo('nansum', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - supports_out=False, - sample_inputs_func=sample_inputs_reduction), # TODO(@heitorschueroff) Add test for dtype kwarg OpInfo('mean', dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), @@ -7458,16 +7397,6 @@ op_db: List[OpInfo] = [ supports_forward_ad=True, skips=( SkipInfo('TestMathBits', 'test_conj_view', device_type='cuda'),),), - OpInfo('prod', - dtypes=all_types_and_complex_and(torch.bool), - dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), - skips=( - # prod does not support the (Tensor, *, out) overload - SkipInfo('TestCommon', 'test_out', - dtypes=[torch.float32]), - ), - sample_inputs_func=sample_inputs_prod, - gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('qr', op=torch.qr, dtypes=floating_and_complex_types(), @@ -8954,12 +8883,166 @@ op_db: List[OpInfo] = [ ), ), ReductionOpInfo( + 'all', + identity=True, + supports_multiple_dims=False, + supports_out=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: does not support passing keepdim without dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: does not support dim=None + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + # FIXME: uint8 input returns uint8 instead of bool + SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'any', + identity=False, + supports_multiple_dims=False, + supports_out=False, + supports_autograd=False, + result_dtype=torch.bool, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: does not support passing keepdim without dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: does not support dim=None + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + # FIXME: uint8 input returns uint8 instead of bool + SkipInfo('TestReductions', 'test_result_dtype', dtypes=[torch.uint8]), + ), + ), + ReductionOpInfo( + 'amax', + nan_policy='propagate', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs), + skips=( + # FIXME: sum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'amin', + nan_policy='propagate', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs), + skips=( + # FIXME: sum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + ), + ), + ReductionOpInfo( + 'argmax', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + # FIXME: keepdim parameter is ignored when dim=None + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'argmin', + supports_multiple_dims=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and(torch.float16, torch.bfloat16), + skips=( + # FIXME: keepdim parameter is ignored when dim=None + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'count_nonzero', + identity=0, + supports_out=False, + supports_autograd=False, + result_dtype=torch.int64, + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_reduction_count_nonzero, + skips=( + # FIXME: count_nonzero does not accept keepdim kwarg + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + SkipInfo('TestReductions', 'test_dim_single_keepdim'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + SkipInfo('TestReductions', 'test_dim_multi_keepdim'), + SkipInfo('TestReductions', 'test_dim_multi_unsorted_keepdim'), + SkipInfo('TestReductions', 'test_dim_offbounds_keepdim'), + # FIXME: dim=[] reduces all dimensions + SkipInfo('TestReductions', 'test_dim_empty'), + ), + ), + ReductionOpInfo( + 'prod', + identity=1, + nan_policy='propagate', + supports_multiple_dims=False, + supports_out=False, + promotes_int_to_int64=True, + gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, + dtypes=all_types_and_complex_and(torch.bool), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_prod, + skips=( + # FIXME: prod does not support passing keepdim without passing dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: prod reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + # FIXME: prod does not support passing None to dim + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( 'sum', identity=0, + nan_policy='propagate', supports_out=False, supports_forward_ad=True, promotes_int_to_int64=True, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: sum does not support passing keepdim without passing dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: sum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + # FIXME: sum does not support passing None to dim + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), + ), + ReductionOpInfo( + 'nansum', + identity=0, + nan_policy='omit', + supports_out=False, + promotes_int_to_int64=True, + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + skips=( + # FIXME: nansum does not support passing keepdim without passing dim + SkipInfo('TestReductions', 'test_dim_default_keepdim'), + # FIXME: nansum reduces all dimensions when dim=[] + SkipInfo('TestReductions', 'test_dim_empty'), + SkipInfo('TestReductions', 'test_dim_empty_keepdim'), + # FIXME: nansum does not support passing None to dim + SkipInfo('TestReductions', 'test_dim_none'), + SkipInfo('TestReductions', 'test_dim_none_keepdim'), + ), ), ] -- 2.7.4