From 70a3210ecaa0162b4673f53faa17675a9d3ca8de Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Fri, 20 Aug 2021 11:43:07 -0700
Subject: [PATCH] Add `BinaryUfuncOpInfo` and broadcasting tests (#61964)

Summary:
As proof of concept, this PR uses the new `BinaryUfuncOpInfo` in broadcasting tests for `add`, `sub`, `mul`, `div`, `floor_div`, and `true_div`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61964

Reviewed By: ngimel

Differential Revision: D30407734

Pulled By: mruberry

fbshipit-source-id: ada28994f43b0635f279f45a02ecba18bc8ee033
---
 test/test_binary_ufuncs.py                         |  80 ++++-
 test/test_jit_fuser_te.py                          |   2 +
 .../_internal/common_methods_invocations.py        | 333 +++++++++++++--------
 3 files changed, 287 insertions(+), 128 deletions(-)

diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index f952911..4995e0d 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -17,8 +17,9 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCUDA, onlyCPU, dtypes, dtypesIfCUDA,
     dtypesIfCPU, deviceCountAtLeast, precisionOverride, onlyOnCPUAndCUDA,
-    skipCUDAIfRocm, skipIf)
+    skipCUDAIfRocm, skipIf, ops)
 from torch.testing import all_types_and_complex_and, integral_types_and
+from torch.testing._internal.common_methods_invocations import binary_ufuncs

 if TEST_SCIPY:
     import scipy.special
@@ -89,6 +90,74 @@ def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:

 # TODO: update to use opinfos consistently
 class TestBinaryUfuncs(TestCase):
+    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
+    def test_broadcasting(self, device, dtype, op):
+        for shape_lhs, shape_rhs in (
+            ((1,), ()),
+            ((2,), ()),
+            ((1,), (2,)),
+            ((2,), (2,)),
+            ((2, 1), (2,)),
+            ((1, 2), (2,)),
+            ((3, 2), (2,)),
+            ((3, 2), (3, 2)),
+            ((1, 3, 2), (2,)),
+            ((1, 3, 2), (3, 2)),
+            ((3, 1, 2), (3, 2)),
+            ((1, 3, 2), (1, 3, 2)),
+            ((2, 3, 2), ()),
+            ((2, 3, 2), (2, 3, 2)),
+            ((3, 1, 2), (1, 3, 2)),
+        ):
+            lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs)
+            rhs = make_tensor(shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs)
+
+            actual = op(lhs, rhs).shape
+            expected = torch.broadcast_shapes(shape_lhs, shape_rhs)
+
+            msg = (
+                f"On {device}, torch.{op.name} broadcasts inputs of shapes {shape_lhs} and {shape_rhs} incorrectly: "
+                f"{actual} != {expected}"
+            )
+            self.assertEqual(actual, expected, msg=msg)
+
+    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
+    def test_broadcast_python_scalar(self, device, dtype, op):
+        for shape_lhs in ((), (1,), (2,), (1, 2, 3),):
+            lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs)
+            rhs_tensor = make_tensor((), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs)
+            rhs_python = rhs_tensor.item()
+
+            actual = op(lhs, rhs_python)
+            expected = op(lhs, rhs_tensor)
+
+            self.assertEqual(
+                actual.shape,
+                expected.shape,
+                msg=f"On {device}, torch.{op.name} broadcasts Python scalars differently than 0d tensors.",
+            )
+
+    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
+    def test_not_broadcastable(self, device, dtype, op):
+        for shape_lhs, shape_rhs in (
+            ((2,), (3,)),
+            ((3, 1), (2, 1)),
+            ((1, 3, 2), (3,)),
+            ((3, 1, 2), (2, 1, 2)),
+        ):
+            lhs = make_tensor(shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs)
+            rhs = make_tensor(shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs)
+
+            try:
+                broadcasted_shape = op(lhs, rhs).shape
+            except RuntimeError:
+                continue
+
+            msg = (
+                f"On {device}, torch.{op.name} broadcasts input shapes {shape_lhs} and {shape_rhs} into "
+                f"{broadcasted_shape}, although they are not broadcastable."
+            )
+            raise AssertionError(msg)

     def test_add_broadcast_empty(self, device):
         # empty + empty
@@ -1184,11 +1253,10 @@ class TestBinaryUfuncs(TestCase):
     # Also tests that reverse operations are equivalent to forward ops
     # NOTE: division ops are tested separately above
     def test_binary_ops_with_scalars(self, device):
-        for ops in ((operator.add, torch.add),
-                    (operator.sub, torch.sub),
-                    (operator.mul, torch.mul),
-                    (operator.truediv, torch.div)):
-            python_op, torch_op = ops
+        for python_op, torch_op in ((operator.add, torch.add),
+                                    (operator.sub, torch.sub),
+                                    (operator.mul, torch.mul),
+                                    (operator.truediv, torch.div)):
             for a, b in product(range(-10, 10), range(-10, 10)):
                 for op in (lambda x: x * .5, lambda x: math.floor(x)):

diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 614226f..b89caca 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -1929,6 +1929,8 @@ works_list = [
     'cosh',
     'div.no_rounding_mode',
     'div.true_rounding',
+    'div.floor_rounding',
+    'div.trunc_rounding',
     'eq',
     'erf',
     'erfc',

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 873d91c..617b102 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1268,53 +1268,151 @@ def sample_inputs_linalg_vector_norm(op_info, device, dtype, requires_grad, **kw
     return inputs


-# In order to use the kwarg alpha, partials should be used in an OpInfo's sample_inputs_func
-# eg. sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2)
-# Then one sample input would also be generated corresponding to the value of alpha provided.
-# In the future, kwargs 'alpha_floating', 'alpha_integral' & 'alpha_complex' can be used to
-# specify scalars of floating, integral & complex types as values for "alpha".
-# Keyword argument `rhs_exclude_zero` is used to exclude zero values from rhs tensor argument
-# This is necessary for operations like `true_divide`, where divide by zero throws an exception.
-def sample_inputs_binary_pwise(op_info, device, dtype, requires_grad, extra_kwargs=None, **kwargs):
-    if extra_kwargs is None:
-        extra_kwargs = {}
-
-    scalar = 3.14 + 3.14j if dtype.is_complex else (3.14 if dtype.is_floating_point else 3)
-    scalar = 1 if dtype is torch.bool else scalar
-    tests_list = [
-        ((S, S, S), (S, S, S), False),
-        ((S, S, S), (S, S), False),
-        ((), (), False),
-        ((S, S, S), (), False),
-        ((S, S, S), scalar, False),
-        ((), scalar, False)
-    ]
-    tests_with_lhs_broadcasting = [
-        ((S, S), (S, S, S), True),
-        ((), (S, S, S), True),
-        ((S, 1, S), (M, S), True),
+
+# Metadata class for binary "universal functions (ufuncs)" that accept two
+# tensors and have common properties
+class BinaryUfuncInfo(OpInfo):
+    """Operator information for 'universal binary functions (binary ufuncs).'
+    These are functions of two tensors with common properties like:
+      - they are elementwise functions
+      - the output shape is determined by broadcasting the input shapes
+      - they typically have method and inplace variants
+      - they typically support the out kwarg
+      - they typically have NumPy or SciPy references
+    See NumPy's universal function documentation
+    (https://numpy.org/doc/stable/reference/ufuncs.html) for more details
+    about the concept of ufuncs.
+    """
+    def __init__(self, name, *, lhs_make_tensor_kwargs=None, rhs_make_tensor_kwargs=None, **kwargs):
+        super().__init__(name, **kwargs)
+
+        # [lr]hs_make_tensor_kwargs are part of the OpInfo to be able to dynamically generate valid samples later on.
+        if lhs_make_tensor_kwargs is None:
+            lhs_make_tensor_kwargs = {}
+        self.lhs_make_tensor_kwargs = lhs_make_tensor_kwargs
+
+        if rhs_make_tensor_kwargs is None:
+            rhs_make_tensor_kwargs = {}
+        self.rhs_make_tensor_kwargs = rhs_make_tensor_kwargs
+
+
+def _resolve_binary_pwise_kwargs(
+    op_info, *, op_kwargs=None, lhs_make_tensor_kwargs=None, rhs_make_tensor_kwargs=None
+):
+    """Resolves default values for :func:`sample_inputs_binary_pwise`.
+
+    By default :attr:`op_kwargs`, :attr:`lhs_make_tensor_kwargs`, and :attr:`rhs_make_tensor_kwargs` are just empty
+    dictionaries. In case :attr:`op_info` is a :class:`BinaryUfuncInfo`, :attr:`BinaryUfuncInfo.lhs_make_tensor_kwargs`
+    and :attr:`BinaryUfuncInfo.rhs_make_tensor_kwargs` will be used as defaults.
+    """
+    if op_kwargs is None:
+        op_kwargs = {}
+    if lhs_make_tensor_kwargs is None:
+        lhs_make_tensor_kwargs = op_info.lhs_make_tensor_kwargs if isinstance(op_info, BinaryUfuncInfo) else {}
+    if rhs_make_tensor_kwargs is None:
+        rhs_make_tensor_kwargs = op_info.rhs_make_tensor_kwargs if isinstance(op_info, BinaryUfuncInfo) else {}
+
+    return op_kwargs, lhs_make_tensor_kwargs, rhs_make_tensor_kwargs
+
+
+def sample_inputs_binary_pwise(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    *,
+    python_scalars=False,
+    op_kwargs=None,
+    lhs_make_tensor_kwargs=None,
+    rhs_make_tensor_kwargs=None,
+    **kwargs,
+):
+    op_kwargs, lhs_make_tensor_kwargs, rhs_make_tensor_kwargs = _resolve_binary_pwise_kwargs(
+        op_info,
+        op_kwargs=op_kwargs,
+        lhs_make_tensor_kwargs=lhs_make_tensor_kwargs,
+        rhs_make_tensor_kwargs=rhs_make_tensor_kwargs,
+    )
+
+    scalar = make_tensor((), device=device, dtype=dtype, **rhs_make_tensor_kwargs)
+    if python_scalars:
+        scalar = scalar.item()  # type: ignore[assignment]
+
+    shapes = [
+        ((), scalar),
+        ((S,), scalar),
+        ((S, 1), (S,)),
+        ((M, S), scalar),
+        ((S, M, S), (M, S)),
+        ((S, M, S), (S, M, S)),
+        ((M, 1, S), (M, S)),
+        ((M, 1, S), (1, M, S)),
     ]
-    test_cases = tests_list + tests_with_lhs_broadcasting  # type: ignore[operator]
-    samples = []
-    for first_shape, shape_or_scalar, broadcasts_input in test_cases:
-        arg = shape_or_scalar
-
-        if isinstance(shape_or_scalar, tuple):
-            exclude_zero = kwargs.get('rhs_exclude_zero', False)
-            arg = make_tensor(shape_or_scalar, device=device, dtype=dtype,
-                              requires_grad=requires_grad, exclude_zero=exclude_zero)
-        samples.append(SampleInput(make_tensor(first_shape, device=device, dtype=dtype,
-                                               requires_grad=requires_grad),
-                                   args=(arg,), kwargs=extra_kwargs,
-                                   broadcasts_input=broadcasts_input))
-    # Adds an extra sample using "alpha" if it's passed in kwargs
-    if 'alpha' in kwargs:
-        a = make_tensor((S, S, S), device=device, dtype=dtype, requires_grad=requires_grad)
-        b = make_tensor((S, S, S), device=device, dtype=dtype, requires_grad=requires_grad)
-        extra_kwargs['alpha'] = kwargs['alpha']
-        sample = SampleInput(a, args=(b,), kwargs=extra_kwargs)
-        samples.append(sample)
-    return tuple(samples)
+
+    sample_inputs = []
+    for shape_lhs, shape_rhs_or_scalar in shapes:
+        lhs = make_tensor(
+            shape_lhs,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            **lhs_make_tensor_kwargs,
+        )
+        if isinstance(shape_rhs_or_scalar, tuple):
+            # shape
+            rhs = make_tensor(
+                shape_rhs_or_scalar,
+                device=device,
+                dtype=dtype,
+                requires_grad=requires_grad,
+                **rhs_make_tensor_kwargs,
+            )
+            broadcasts_input = torch.broadcast_shapes(shape_lhs, shape_rhs_or_scalar) != shape_lhs
+        else:
+            # scalar
+            rhs = shape_rhs_or_scalar  # type: ignore[assignment]
+            broadcasts_input = False
+
+        sample_inputs.append(SampleInput(lhs, args=(rhs,), kwargs=op_kwargs, broadcasts_input=broadcasts_input))
+    return sample_inputs
+
+
+def sample_inputs_add_sub(
+    op_info,
+    device,
+    dtype,
+    requires_grad,
+    python_scalars=False,
+    alpha=1,
+    op_kwargs=None,
+    lhs_make_tensor_kwargs=None,
+    rhs_make_tensor_kwargs=None,
+    **kwargs,
+):
+    op_kwargs, lhs_make_tensor_kwargs, rhs_make_tensor_kwargs = _resolve_binary_pwise_kwargs(
+        op_info,
+        op_kwargs=op_kwargs,
+        lhs_make_tensor_kwargs=lhs_make_tensor_kwargs,
+        rhs_make_tensor_kwargs=rhs_make_tensor_kwargs,
+    )
+
+    sample_inputs = sample_inputs_binary_pwise(
+        op_info,
+        device,
+        dtype,
+        requires_grad,
+        python_scalars=python_scalars,
+        op_kwargs=op_kwargs,
+        lhs_make_tensor_kwargs=lhs_make_tensor_kwargs,
+        rhs_make_tensor_kwargs=rhs_make_tensor_kwargs,
+        **kwargs,
+    )
+
+    lhs = make_tensor((S, S), device=device, dtype=dtype, requires_grad=requires_grad, **lhs_make_tensor_kwargs)
+    rhs = make_tensor((S, S), device=device, dtype=dtype, requires_grad=requires_grad, **rhs_make_tensor_kwargs)
+    sample_inputs.append(SampleInput(lhs, args=(rhs,), kwargs=dict(op_kwargs, alpha=alpha), broadcasts_input=False))
+
+    return sample_inputs


 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
@@ -4045,19 +4143,6 @@ def sample_inputs_logit(op_info, device, dtype, requires_grad, **kwargs):
     return samples


-def sample_inputs_floor_divide(op_info, device, dtype, requires_grad, **kwargs):
-    lhs = make_tensor((S, S, S), device, dtype, low=None, high=None, requires_grad=requires_grad)
-    rhs = make_tensor((S, S, S), device, dtype, low=None, high=None, requires_grad=requires_grad)
-    # Avoid integer divide by 0
-    if not (dtype.is_floating_point or dtype.is_complex):
-        rhs[rhs == 0] = 1
-
-    return [
-        SampleInput(lhs, args=(rhs,)),
-        SampleInput(lhs, args=(rhs[0],)),
-        SampleInput(lhs, args=(3.14,)),
-    ]
-
 def sample_inputs_isin(op_info, device, dtype, requires_grad):
     element = make_tensor((L,), device, dtype, low=None, high=None, requires_grad=requires_grad)
     indices = torch.randint(0, L, size=[S])
@@ -5452,29 +5537,29 @@ op_db: List[OpInfo] = [
                    SkipInfo('TestGradients', 'test_forward_mode_AD',
                             dtypes=[torch.cdouble]),
                )),
-    OpInfo('add',
-           # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
-           ref=lambda input, other, *, alpha=1: np.add(input, np.multiply(alpha, other)),
-           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
-           assert_autodiffed=True,
-           sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2),
-           supports_inplace_autograd=False,
-           supports_forward_ad=True),
-    OpInfo('mul',
-           aliases=('multiply',),
-           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
-           assert_autodiffed=True,
-           supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_binary_pwise),
-    OpInfo('sub',
-           # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
-           ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
-           aliases=('subtract',),
-           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
-           assert_autodiffed=True,
-           supports_forward_ad=True,
-           sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2),
-           supports_inplace_autograd=False),
+    BinaryUfuncInfo('add',
+                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
+                    ref=lambda input, other, *, alpha=1: np.add(input, np.multiply(alpha, other)),
+                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
+                    assert_autodiffed=True,
+                    sample_inputs_func=partial(sample_inputs_add_sub, alpha=2),
+                    supports_inplace_autograd=False,
+                    supports_forward_ad=True),
+    BinaryUfuncInfo('mul',
+                    aliases=('multiply',),
+                    dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16, torch.bool),
+                    assert_autodiffed=True,
+                    supports_forward_ad=True,
+                    sample_inputs_func=sample_inputs_binary_pwise),
+    BinaryUfuncInfo('sub',
+                    # NumPy has no builtin reference for the alpha kwarg, but it is easy enough to emulate
+                    ref=lambda input, other, *, alpha=1: np.subtract(input, np.multiply(alpha, other)),
+                    aliases=('subtract',),
+                    dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
+                    assert_autodiffed=True,
+                    supports_forward_ad=True,
+                    sample_inputs_func=partial(sample_inputs_add_sub, alpha=2),
+                    supports_inplace_autograd=False),
     OpInfo('addmm',
            # This addmm OpInfo is for when alpha and beta are not both equal to 1.
            # alpha=beta=1 is tested in the following opinfo, because that special case will
@@ -6029,41 +6114,43 @@ op_db: List[OpInfo] = [
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_diff),
-    OpInfo('div',
-           aliases=('divide',),
-           variant_test_name='no_rounding_mode',
-           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
-           sample_inputs_func=partial(sample_inputs_binary_pwise, rhs_exclude_zero=True),
-           supports_forward_ad=True,
-           assert_autodiffed=True),
-    OpInfo('div',
-           aliases=('divide',),
-           variant_test_name='trunc_rounding',
-           dtypes=all_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=partial(sample_inputs_binary_pwise, extra_kwargs={
-               "rounding_mode": 'trunc'}, rhs_exclude_zero=True),
-           supports_forward_ad=True,
-           skips=(
-               # Reference: https://github.com/pytorch/pytorch/issues/59174
-               SkipInfo('TestJit', 'test_variant_consistency_jit'),
-           ),
-           assert_autodiffed=True),
-    OpInfo('div',
-           aliases=('divide',),
-           variant_test_name='floor_rounding',
-           dtypes=all_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=partial(sample_inputs_binary_pwise, extra_kwargs={
-               "rounding_mode": 'floor'}, rhs_exclude_zero=True),
-           supports_forward_ad=True,
-           skips=(
-               # Reference: https://github.com/pytorch/pytorch/issues/59174
-               SkipInfo('TestJit', 'test_variant_consistency_jit'),
-           ),
-           assert_autodiffed=True),
-    OpInfo('true_divide',
-           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
-           supports_forward_ad=True,
-           sample_inputs_func=partial(sample_inputs_binary_pwise, rhs_exclude_zero=True)),
+    BinaryUfuncInfo('div',
+                    aliases=('divide',),
+                    variant_test_name='no_rounding_mode',
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    sample_inputs_func=sample_inputs_binary_pwise,
+                    supports_forward_ad=True,
+                    assert_autodiffed=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+    BinaryUfuncInfo('div',
+                    aliases=('divide',),
+                    variant_test_name='trunc_rounding',
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    sample_inputs_func=partial(sample_inputs_binary_pwise, op_kwargs=dict(rounding_mode="trunc")),
+                    supports_forward_ad=True,
+                    skips=(
+                        # Reference: https://github.com/pytorch/pytorch/issues/59174
+                        SkipInfo('TestJit', 'test_variant_consistency_jit'),
+                    ),
+                    assert_autodiffed=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+    BinaryUfuncInfo('div',
+                    aliases=('divide',),
+                    variant_test_name='floor_rounding',
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    sample_inputs_func=partial(sample_inputs_binary_pwise, op_kwargs=dict(rounding_mode="floor")),
+                    supports_forward_ad=True,
+                    skips=(
+                        # Reference: https://github.com/pytorch/pytorch/issues/59174
+                        SkipInfo('TestJit', 'test_variant_consistency_jit'),
+                    ),
+                    assert_autodiffed=True,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+    BinaryUfuncInfo('true_divide',
+                    dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+                    supports_forward_ad=True,
+                    sample_inputs_func=sample_inputs_binary_pwise,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
     UnaryUfuncInfo('exp',
                    ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
@@ -6316,11 +6403,12 @@ op_db: List[OpInfo] = [
            dtypes=all_types_and(torch.bool, torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.bool, torch.bfloat16, torch.float16),
            safe_casts_outputs=True),
-    OpInfo('floor_divide',
-           dtypes=all_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_floor_divide,
-           supports_autograd=False,
-           ),
+    BinaryUfuncInfo('floor_divide',
+                    dtypes=all_types_and(torch.half, torch.bfloat16),
+                    sample_inputs_func=sample_inputs_binary_pwise,
+                    supports_autograd=False,
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    ),
     UnaryUfuncInfo('frexp',
                    op=torch.frexp,
                    ref=np.frexp,
@@ -8752,6 +8840,7 @@ op_db: List[OpInfo] = [

 # Common operator groupings
 unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)]
+binary_ufuncs = [op for op in op_db if isinstance(op, BinaryUfuncInfo)]
 spectral_funcs = [op for op in op_db if isinstance(op, SpectralFuncInfo)]
 sparse_unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo) and op.supports_sparse is True]
 shape_funcs = [op for op in op_db if isinstance(op, ShapeFuncInfo)]
--
2.7.4
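
Reviewer note (not part of the patch): the sketch below illustrates how the new `BinaryUfuncInfo` metadata is meant to be consumed by a device-generic test, mirroring `test_broadcasting` in the patch. The class and test names here are hypothetical; `binary_ufuncs`, the `@ops` decorator, and the per-operator `lhs_make_tensor_kwargs` / `rhs_make_tensor_kwargs` attributes come from the patch itself.

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, make_tensor
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops
from torch.testing._internal.common_methods_invocations import binary_ufuncs


class TestBinaryUfuncUsageSketch(TestCase):
    # One test is generated per BinaryUfuncInfo in op_db, per device and allowed dtype.
    @ops(binary_ufuncs, allowed_dtypes=(torch.float32,))
    def test_result_shape_follows_broadcasting(self, device, dtype, op):
        # The make_tensor kwargs come straight from the op's BinaryUfuncInfo entry,
        # e.g. rhs_make_tensor_kwargs=dict(exclude_zero=True) for div, so the sample
        # rhs never contains zeros.
        lhs = make_tensor((2, 1, 3), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs)
        rhs = make_tensor((2, 3), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs)

        actual = op(lhs, rhs).shape
        expected = torch.broadcast_shapes(lhs.shape, rhs.shape)
        self.assertEqual(actual, expected)


instantiate_device_type_tests(TestBinaryUfuncUsageSketch, globals())

if __name__ == "__main__":
    run_tests()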