From c562ebca233c3ad16357753accdd04dc1ab2ab88 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Mon, 20 Sep 2021 10:11:29 -0700 Subject: [PATCH] Revert "Revert D30558877: Ported std/var to ReductionOpInfo (#65262) Summary: Reland of https://github.com/pytorch/pytorch/issues/63978 Pull Request resolved: https://github.com/pytorch/pytorch/pull/65262 Reviewed By: mruberry Differential Revision: D31037360 Pulled By: ngimel fbshipit-source-id: 1c60f40c547229767cba3bbe7e11ca0fbbc8f95f --- .../_internal/common_methods_invocations.py | 194 +++++++++++++++------ 1 file changed, 145 insertions(+), 49 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6331c31..9d5867c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5890,6 +5890,39 @@ def reference_reduction_numpy(f, supports_keepdims=True): return wrapper +def reference_std_var(f): + """Forwards unbiased/correction kwargs as NumPy's equivalent ddof""" + g = reference_reduction_numpy(f) + + @wraps(g) + def wrapper(x: np.ndarray, *args, **kwargs): + assert not ('unbiased' in kwargs and 'correction' in kwargs) + + if 'unbiased' in kwargs: + kwargs['ddof'] = int(kwargs.pop('unbiased')) + elif 'correction' in kwargs: + kwargs['ddof'] = kwargs.pop('correction') + + return g(x, *args, **kwargs) + + return wrapper + + +def generate_std_var_kwargs(t: torch.Tensor, **kwargs): + """Generates unbiased/correction kwargs for std/var operators""" + yield ((), {'unbiased': True}) + yield ((), {'unbiased': False}) + + # Currently, calling std with correction is only enabled when + # both dim and keepdim are provided. + if 'dim' in kwargs and 'keepdim' in kwargs: + yield ((), {'correction': 0}) + yield ((), {'correction': 1}) + + numel = torch.tensor(t.shape)[kwargs.get('dim')].prod() + yield ((), {'correction': numel // 2}) + + # Operator database (sorted alphabetically) op_db: List[OpInfo] = [ UnaryUfuncInfo('abs', @@ -7322,14 +7355,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), )), OpInfo('max', - op=torch.max, - variant_test_name='binary', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - sample_inputs_func=sample_inputs_max_min_binary, - supports_forward_ad=True, - assert_autodiffed=True,), - OpInfo('max', - op=torch.max, variant_test_name='reduction_with_dim', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, @@ -7338,7 +7363,6 @@ op_db: List[OpInfo] = [ # max does not correctly warn when resizing out= inputs DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),)), OpInfo('max', - op=torch.max, variant_test_name='reduction_no_dim', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), supports_out=False, @@ -7431,14 +7455,6 @@ op_db: List[OpInfo] = [ autodiff_nonfusible_nodes=[], supports_forward_ad=True), OpInfo('min', - op=torch.min, - variant_test_name='binary', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - sample_inputs_func=sample_inputs_max_min_binary, - supports_forward_ad=True, - assert_autodiffed=True,), - OpInfo('min', - op=torch.min, variant_test_name='reduction_with_dim', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, @@ -7448,7 +7464,6 @@ op_db: List[OpInfo] = [ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), )), OpInfo('min', - op=torch.min, variant_test_name='reduction_no_dim', dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), supports_out=False, @@ -7460,16 +7475,56 @@ op_db: List[OpInfo] = [ OpInfo('nanquantile', dtypes=floating_types(), sample_inputs_func=sample_inputs_reduction_quantile), - OpInfo('maximum', - op=torch.maximum, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - supports_forward_ad=True, - sample_inputs_func=sample_inputs_max_min_binary,), - OpInfo('minimum', - op=torch.minimum, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - supports_forward_ad=True, - sample_inputs_func=sample_inputs_max_min_binary,), + BinaryUfuncInfo( + 'max', + aliases=('maximum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_max_min_binary, + supports_forward_ad=True, + assert_autodiffed=True, + ref=np.maximum, + skips=( + # FIXME: maximum does not accept scalar inputs + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'), + ), + ), + BinaryUfuncInfo( + 'maximum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + sample_inputs_func=sample_inputs_max_min_binary, + ref=np.maximum, + skips=( + # FIXME: maximum does not accept scalar inputs + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'), + ), + ), + BinaryUfuncInfo( + 'min', + aliases=('minimum',), + variant_test_name='binary', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + sample_inputs_func=sample_inputs_max_min_binary, + supports_forward_ad=True, + assert_autodiffed=True, + ref=np.minimum, + skips=( + # FIXME: min does not accept scalar inputs + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'), + ), + ), + BinaryUfuncInfo( + 'minimum', + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + supports_forward_ad=True, + sample_inputs_func=sample_inputs_max_min_binary, + ref=np.minimum, + skips=( + # FIXME: minimum does not accept scalar inputs + DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'), + ), + ), # `softmax` supports different dtypes based on whether `dtype` argument, # is passed or not. Hence two OpInfo entries, one with dtype and other without. OpInfo('softmax', @@ -8246,16 +8301,6 @@ op_db: List[OpInfo] = [ sample_inputs_func=sample_inputs_legacy_solve, check_batched_gradgrad=False, decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]), - OpInfo('std', - dtypes=floating_and_complex_types_and(torch.half), - dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - backward_dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), - sample_inputs_func=sample_inputs_std_var, - # TODO: std does support out in some signatures - supports_out=False, - assert_autodiffed=True, - ), UnaryUfuncInfo('tan', ref=np.tan, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), @@ -9034,17 +9079,6 @@ op_db: List[OpInfo] = [ assert_jit_shape_analysis=True, assert_autodiffed=True, sample_inputs_func=sample_unsqueeze), - OpInfo('var', - dtypes=floating_and_complex_types_and(torch.half), - dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), - dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - backward_dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), - backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), - sample_inputs_func=sample_inputs_std_var, - # TODO: revisit, some var signatures do support out (see std, too) - supports_out=False, - assert_autodiffed=True, - ), OpInfo('xlogy', aliases=('special.xlogy',), dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), @@ -9670,6 +9704,68 @@ op_db: List[OpInfo] = [ ), ), ReductionOpInfo( + 'std', + nan_policy='propagate', + supports_out=False, + assert_autodiffed=True, + promotes_int_to_float=True, + dtypes=floating_and_complex_types_and(torch.half), + dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.std), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=None not supported + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # TODO(@heitorschueroff) std return float for complex types + # need to find a better way to model result dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_result_dtype'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'), + # NumPy is giving NaN for this + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'), + ), + ), + ReductionOpInfo( + 'var', + nan_policy='propagate', + supports_out=False, + assert_autodiffed=True, + promotes_int_to_float=True, + dtypes=floating_and_complex_types_and(torch.half), + dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16), + dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_std_var, + ref=reference_std_var(np.var), + generate_args_kwargs=generate_std_var_kwargs, + skips=( + # FIXME: cannot specify keepdim without dim + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), + # FIXME: dim=None not supported + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), + # FIXME: dim=[] reduces all dimensions + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), + # TODO(@heitorschueroff) std return float for complex types + # need to find a better way to model result dtype + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_result_dtype'), + # FIXME: improve precision + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'), + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'), + # NumPy is giving NaN for this + DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'), + ), + ), + ReductionOpInfo( 'prod', identity=1, nan_policy='propagate', -- 2.7.4