From: Heitor Schueroff
Date: Thu, 26 Aug 2021 13:05:28 +0000 (-0700)
Subject: [OpInfo] Added ReductionOpInfo subclass of OpInfo and ported sum test (#62737)

[OpInfo] Added ReductionOpInfo subclass of OpInfo and ported sum test (#62737)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62737

ReductionOpInfo is a specialization of OpInfo for reduction operators. For now,
it is designed to work with reductions that return a single tensor and that
reduce all elements along one or more dimensions to a single value. In
particular this excludes operators such as `max` and `min` that return multiple
tensors and `quantile` that can return multiple values.

fixes https://github.com/pytorch/pytorch/issues/49746

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30406568

Pulled By: heitorschueroff

fbshipit-source-id: 218b1da1902f67bcf4c3681e2a0f0029a25d51f1
---

diff --git a/test/test_ops.py b/test/test_ops.py
index 76a7b6a..a6baf8d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -10,7 +10,7 @@ from torch.testing._internal.common_utils import \
     (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor,
      gradcheck, gradgradcheck, IS_IN_CI, suppress_warnings)
 from torch.testing._internal.common_methods_invocations import \
-    (op_db, _NOTHING, UnaryUfuncInfo, SpectralFuncInfo)
+    (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo)
 from torch.testing._internal.common_device_type import \
     (deviceCountAtLeast, instantiate_device_type_tests, ops, onlyCUDA, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
 from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
@@ -27,8 +27,8 @@ _variant_ops = partial(ops, dtypes=OpDTypes.supported,
 # Get names of all the operators which have ref in their entry in OpInfo (testing infra)
 # except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py)
 # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
-_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, SpectralFuncInfo)) and
-                            op.ref is not None and op.ref is not _NOTHING, op_db))
+_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo,
+                            SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db))
 
 
 # Tests that apply to all operators and aren't related to any particular
diff --git a/test/test_reductions.py b/test/test_reductions.py
index c1da0f0..e224eae 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -14,7 +14,9 @@ from torch.testing._internal.common_utils import (
     IS_WINDOWS, make_tensor)
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
-    onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, precisionOverride)
+    onlyOnCPUAndCUDA, onlyCUDA, largeTensorTest, ops, precisionOverride)
+from torch.testing._internal.common_methods_invocations import (
+    ReductionOpInfo, reduction_ops)
 
 # TODO: replace with make_tensor
 def _generate_input(shape, dtype, device, with_extremal):
@@ -55,6 +57,21 @@ def _rand_shape(dim, min_size, max_size):
 
 class TestReductions(TestCase):
 
+    ###########################################################################
+    # ReductionOpInfo unit tests
+    ###########################################################################
+
+    @ops(reduction_ops, allowed_dtypes=[torch.float])
+    def test_dim_default(self, device, dtype, op: ReductionOpInfo):
+        """Tests that the default behavior is to reduce all dimensions."""
+        t = make_tensor((2, 3), device, dtype)
+        args, kwargs = next(op.generate_args_kwargs(t))
+        self.assertEqual(op(t, *args, **kwargs).ndim, 0)
+
+    ###########################################################################
+    # TODO: Legacy tests - port to ReductionOpInfo
+    ###########################################################################
+
     def test_var_unbiased(self, device):
         tensor = torch.randn(100, device=device)
         self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b725c48..3839b2e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -13,7 +13,7 @@ import numpy as np
 from torch._six import inf
 import collections.abc
 
-from typing import List, Sequence, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 from torch.testing import \
     (make_non_contiguous, floating_types, floating_types_and, complex_types,
@@ -43,6 +43,15 @@ if TEST_SCIPY:
     import scipy.special
 
 
+# Reasonable testing sizes for dimensions
+L = 20
+M = 10
+S = 5
+
+# Unique value to distinguish default from anything else
+_NOTHING = object()
+
+
 class DecorateInfo(object):
     """Describes which test, or type of tests, should be wrapped in the given
        decorators when testing an operator. Any test that matches all provided
@@ -92,6 +101,7 @@ class SkipInfo(DecorateInfo):
                          device_type=device_type, dtypes=dtypes, active_if=active_if)
 
 
+
 class SampleInput(object):
     """Represents sample inputs to a function."""
 
@@ -185,6 +195,7 @@ class SampleInput(object):
         sample_np_input, np_args, np_kwargs = to_numpy(self.input), to_numpy(self.args), to_numpy(self.kwargs)
         return (sample_np_input, np_args, np_kwargs)
 
+
 class AliasInfo(object):
     """Class holds alias information. For example, torch.abs ->
     torch.absolute, torch.Tensor.absolute, torch.Tensor.absolute_
@@ -200,9 +211,6 @@ class AliasInfo(object):
         return self.op(*args, **kwargs)
 
 
-_NOTHING = object()  # Unique value to distinguish default from anything else
-
-
 # Extension of getattr to support qualified names
 # e.g. _getattr_qual(torch, 'linalg.norm') -> torch.linalg.norm
 def _getattr_qual(obj, name, default=_NOTHING):
@@ -770,9 +778,164 @@ class OpInfo(object):
                 else supported.intersection(self._default_test_dtypes))
 
 
-L = 20
-M = 10
-S = 5
+def _generate_reduction_inputs(device, dtype, requires_grad):
+    """Generates input tensors for testing reduction operators"""
+    yield make_tensor([], device, dtype, requires_grad=requires_grad)
+    yield make_tensor([2], device, dtype, requires_grad=requires_grad)
+    yield make_tensor([2, 3], device, dtype, requires_grad=requires_grad, noncontiguous=True)
+    yield make_tensor([3, 2, 1, 5], device, dtype, requires_grad=requires_grad)
+
+
+def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
+    """Generates a subset of all valid dim and keepdim kwargs given ndim that
+    is appropriate for testing reduction operators.
+    """
+
+    # Test default dim and keepdim
+    yield {}
+
+    # Test reducing the innermost and outermost dimensions
+    yield {'dim': 0, 'keepdim': True}
+    yield {'dim': -1, 'keepdim': False}
+
+    # Test reducing a middle dimension
+    if ndim > 2:
+        yield {'dim': ndim // 2, 'keepdim': True}
+
+    if supports_multiple_dims:
+        # Test reducing all dimensions
+        yield {'dim': tuple(range(ndim)), 'keepdim': False}
+
+        # Test reducing both the first and last dimensions
+        if ndim > 1:
+            yield {'dim': (0, -1), 'keepdim': True}
+
+        # Test reducing every other dimension starting with the second
+        if ndim > 3:
+            yield {'dim': tuple(range(1, ndim, 2)), 'keepdim': False}
+
+
+def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for reduction operators."""
+
+    # TODO(@heitorschueroff) Once all reduction operators are using
+    # ReductionOpInfo use op_info.supports_multiple_dims directly.
+    supports_multiple_dims: bool = kwargs.get('supports_multiple_dims', True)
+
+    # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
+    # use op_info.generate_args_kwargs directly.
+    generate_args_kwargs = kwargs.get('generate_args_kwargs', lambda *args, **kwargs: (yield tuple(), {}))
+
+    inputs: List[SampleInput] = []
+    for t in _generate_reduction_inputs(device, dtype, requires_grad):
+        for reduction_kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims):
+            for args, kwargs in generate_args_kwargs(t, **reduction_kwargs):
+                kwargs.update(reduction_kwargs)
+                inputs.append(SampleInput(t, args=args, kwargs=kwargs))
+
+    return inputs
+
+
+# NOTE [Reductions]:
+#
+# For testing purposes, we relax the definition of a reduction operator
+# as defined in the docstring below. We do this to capture operators with
+# a similar API so they can be tested automatically.
+#
+# Strictly speaking, a reduction operator is an operator that can reduce an
+# array to a single scalar value and that can be computed from the partial
+# results of reducing subarrays. This usually means that the reduction operation
+# should be commutative and associative. This definition is important when it
+# comes to implementation as it determines how a reduction can be parallelized.
+#
+# For example, many summary statistics such as median, mode and quantile cannot
+# be computed from partial results because these are sorting- and counting-based
+# algorithms that need information that would be lost in the reduced value.
+class ReductionOpInfo(OpInfo):
+    """Reduction operator information.
+
+    An operator is a reduction operator if it reduces one or more dimensions of
+    the input tensor to a single value. Reduction operators must implement the
+    following signature:
+
+    - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor`
+
+    ReductionOpInfo tests that reduction operators implement a consistent API.
+    Optional features such as reducing over multiple dimensions are captured in
+    the optional keyword parameters of the ReductionOpInfo constructor.
+
+    If a reduction operator does not yet implement the full required API, this
+    should be documented by skipping the failing tests rather than adding
+    optional parameters to ReductionOpInfo.
+
+    NOTE
+    The API for reduction operators has not yet been finalized and some
+    requirements may change.
+
+    See tests in test/test_reductions.py
+    """
+
+    def __init__(
+        self, name, *,
+
+        # The identity value for the operator if it has one.
+        identity: Optional[Any] = None,
+
+        # The nan policy for the operator if it implements one.
+        # - propagate: NaN values are propagated to the output
+        # - omit: NaN values are discarded during the reduction
+        nan_policy: Optional[str] = None,
+
+        # Whether the operator supports reducing multiple dimensions.
+        supports_multiple_dims: bool = True,
+
+        # Whether the operator promotes integral to floating point dtypes.
+        promotes_int_to_float: bool = False,
+
+        # Whether the operator promotes all integral dtypes to int64.
+        promotes_int_to_int64: bool = False,
+
+        # If a specific dtype is given, then the operator always returns that
+        # dtype irrespective of the input dtype. If None, the operator returns
+        # the dtype according to the type promotion rules above.
+        result_dtype: Optional[torch.dtype] = None,
+
+        # ReductionOpInfo tests generate their own input, dim and keepdim
+        # arguments and call this function to generate tuples of extra args and
+        # kwargs to use when calling the op. This is required for operators that
+        # have other required parameters besides the input tensor.
+        generate_args_kwargs: Callable = lambda t, dim=None, keepdim=False: (yield tuple(), {}),
+
+        # Options from the OpInfo base class
+        **kwargs,
+    ):
+        assert nan_policy in (None, 'propagate', 'omit')
+
+        # These are mutually exclusive options
+        assert not (result_dtype and promotes_int_to_float)
+        assert not (result_dtype and promotes_int_to_int64)
+        assert not (promotes_int_to_float and promotes_int_to_int64)
+
+        # Default sample_inputs_func for ReductionOpInfo which augments sample
+        # inputs from sample_inputs_reduction with the args and kwargs from
+        # generate_args_kwargs. This is only used if sample_inputs_func is None.
+        def sample_inputs_func(*args, **kwargs):
+            kwargs['supports_multiple_dims'] = supports_multiple_dims
+            kwargs['generate_args_kwargs'] = generate_args_kwargs
+            return sample_inputs_reduction(*args, **kwargs)
+
+        # Override OpInfo defaults and call base class __init__
+        kwargs.setdefault('inplace_variant', None)
+        kwargs.setdefault('sample_inputs_func', sample_inputs_func)
+        super(ReductionOpInfo, self).__init__(name, **kwargs)
+
+        self.identity = identity
+        self.nan_policy = nan_policy
+        self.supports_multiple_dims = supports_multiple_dims
+        self.promotes_int_to_float = promotes_int_to_float
+        self.promotes_int_to_int64 = promotes_int_to_int64
+        self.result_dtype = result_dtype
+        self.generate_args_kwargs = generate_args_kwargs
 
 
 def sample_inputs_unary(op_info, device, dtype, requires_grad, **kwargs):
@@ -2452,56 +2615,6 @@ def sample_inputs_max_min_reduction_no_dim(op_info, device, dtype, requires_grad
                                        requires_grad=requires_grad),))
     return inputs
 
-# Generates input tensors for testing reduction ops
-def _generate_reduction_inputs(device, dtype, requires_grad):
-    yield make_tensor((), device, dtype, requires_grad=requires_grad)
-    yield make_tensor((2,), device, dtype, requires_grad=requires_grad)
-    yield make_tensor((2, 3), device, dtype, requires_grad=requires_grad, noncontiguous=True)
-    yield make_tensor((3, 2, 1, 2, 2), device, dtype, requires_grad=requires_grad)
-
-# Generates a subset of possible dim and keepdim kwargs for a tensor
-# with ndim dims appropriate for testing. If supports_multiple_dims
-# is True (default) then dim kwarg can be a list of dims.
-def _generate_reduction_kwargs(ndim, supports_multiple_dims=True):
-    for keepdim in [True, False]:
-        # Always test reducing inner and outer most dimensions
-        yield {'dim': 0, 'keepdim': keepdim}
-        yield {'dim': -1, 'keepdim': keepdim}
-
-        # Also reduce middle dimension
-        if ndim > 2:
-            yield {'dim': ndim // 2, 'keepdim': keepdim}
-
-        if supports_multiple_dims:
-            # Always test reducing all dims
-            yield {'dim': tuple(range(ndim)), 'keepdim': keepdim}
-
-            # Test reducing both first and last dimensions
-            if ndim > 1:
-                yield {'dim': (0, ndim - 1), 'keepdim': keepdim}
-
-            # Test reducing every other dimension starting with the second
-            if ndim > 3:
-                yield {'dim': tuple(range(1, ndim, 2)), 'keepdim': keepdim}
-
-# Wraps sample_inputs_reduction function to provide the additional supports_multiple_dims args
-def sample_inputs_reduction_wrapper(supports_multiple_dims):
-    # Generates sample inputs for reduction ops that contain the input tensor
-    # and dim and keepdim kwargs. If a reduction op needs to test additional
-    # args/kwargs then create a separate sample_inputs function
-    def fn(op_info, device, dtype, requires_grad):
-        inputs = []
-
-        for t in _generate_reduction_inputs(device, dtype, requires_grad):
-            # Add case without dim and keepdim kwargs
-            inputs.append(SampleInput(t))
-            for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims):
-                inputs.append(SampleInput(t, kwargs=kwargs))
-
-        return inputs
-
-    return fn
-
 def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad):
     test_quantiles = (0.5, make_tensor((2,), device, dtype, low=0, high=1))
     test_interpolations = ['linear', 'midpoint']
@@ -2513,6 +2626,8 @@ def sample_inputs_reduction_quantile(op_info, device, dtype, requires_grad):
             inputs.append(SampleInput(t, args=(quantiles,)))
         for kwargs in _generate_reduction_kwargs(t.ndim, supports_multiple_dims=False):
             # Interpolation kwarg for now is only supported when providing both dim and keepdim
+            kwargs.setdefault('dim', 0)
+            kwargs.setdefault('keepdim', False)
             for interpolation in test_interpolations:
                 kwargs['interpolation'] = interpolation
                 inputs.append(SampleInput(t, args=(quantiles,), kwargs=kwargs))
@@ -6875,19 +6990,19 @@ op_db: List[OpInfo] = [
            dtypesIfCUDA=all_types_and(torch.float16),
            # TODO: some signatures of median do support out
            supports_out=False,
-           sample_inputs_func=sample_inputs_reduction_wrapper(False)),
+           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
     OpInfo('nanmedian',
            dtypes=all_types(),
            dtypesIfCPU=all_types_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.float16),
            # TODO: some signatures of nanmedian do support out
            supports_out=False,
-           sample_inputs_func=sample_inputs_reduction_wrapper(False)),
+           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False)),
     OpInfo('var_mean',
            dtypes=floating_and_complex_types_and(torch.half),
            dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_reduction_wrapper(False),
+           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False),
            backward_dtypes=floating_types_and(torch.half),
            backward_dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16),
            backward_dtypesIfCUDA=floating_types_and(torch.half),
@@ -6906,7 +7021,7 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types_and(torch.half),
            dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_reduction_wrapper(False),
+           sample_inputs_func=partial(sample_inputs_reduction, supports_multiple_dims=False),
            backward_dtypes=floating_types_and(torch.half),
            backward_dtypesIfCPU=floating_types_and(torch.half, torch.bfloat16),
            backward_dtypesIfCUDA=floating_types_and(torch.half),
@@ -6981,21 +7096,16 @@ op_db: List[OpInfo] = [
            supports_out=False,
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_max_min_reduction_no_dim,),
-    OpInfo('sum',
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
-           supports_out=False,
-           supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True)),
     OpInfo('nansum',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            supports_out=False,
-           sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True)),
+           sample_inputs_func=sample_inputs_reduction),
     # TODO(@heitorschueroff) Add test for dtype kwarg
     OpInfo('mean',
            dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            assert_autodiffed=True,
            supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_reduction_wrapper(supports_multiple_dims=True),
+           sample_inputs_func=sample_inputs_reduction,
            # Need to skip out test because one of the overload for mean does not support it
            # TODO(@heitorschueroff) fix this when implementing ReductionInfo
            skips=(SkipInfo('TestCommon', 'test_out'),)),
@@ -8843,6 +8953,14 @@ op_db: List[OpInfo] = [
             ),
         ),
     ),
+    ReductionOpInfo(
+        'sum',
+        identity=0,
+        supports_out=False,
+        supports_forward_ad=True,
+        promotes_int_to_int64=True,
+        dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+    ),
 ]
 
 # Common operator groupings
@@ -8851,6 +8969,7 @@ binary_ufuncs = [op for op in op_db if isinstance(op, BinaryUfuncInfo)]
 spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
 sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse is True]
 shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]
+reduction_ops = [op for op in op_db if isinstance(op, ReductionOpInfo)]
 
 # TODO: review porting these to make_tensor
 def index_variable(shape, max_indices, device=torch.device('cpu')):
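
Usage sketch (illustrative only, not part of the patch): the `generate_args_kwargs`
hook added above is the intended extension point for reductions that take required
arguments besides the input tensor. The operator name `mynorm` and its `p` argument
below are hypothetical; `ReductionOpInfo` and its keyword parameters are the ones
introduced in this patch.

    # Hypothetical generator: yields (args, kwargs) pairs for each sample input;
    # sample_inputs_reduction merges the generated dim/keepdim kwargs back in.
    def generate_mynorm_args_kwargs(t, dim=None, keepdim=False):
        for p in (1, 2):
            yield (p,), {}

    ReductionOpInfo(
        'mynorm',  # hypothetical: mynorm(input, p, *, dim=None, keepdim=False)
        identity=0,
        nan_policy='propagate',
        supports_multiple_dims=True,
        promotes_int_to_float=True,
        dtypes=floating_types_and(torch.float16, torch.bfloat16),
        generate_args_kwargs=generate_mynorm_args_kwargs,
    )

Entries like this are picked up automatically by tests decorated with
`@ops(reduction_ops)`, such as the new `test_dim_default` above.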