Revert "Revert D30558877: Ported std/var to ReductionOpInfo (#65262)
author     Natalia Gimelshein <ngimel@fb.com>
           Mon, 20 Sep 2021 17:11:29 +0000 (10:11 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Mon, 20 Sep 2021 17:36:06 +0000 (10:36 -0700)
Summary:
Reland of https://github.com/pytorch/pytorch/issues/63978

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65262

Reviewed By: mruberry

Differential Revision: D31037360

Pulled By: ngimel

fbshipit-source-id: 1c60f40c547229767cba3bbe7e11ca0fbbc8f95f
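
For reference, the new `reference_std_var` wrapper below maps PyTorch's
`unbiased`/`correction` kwargs onto NumPy's `ddof` so that `np.std`/`np.var`
can serve as reference implementations. A minimal sketch of that
correspondence (illustrative only, not part of the diff; the tensor shape and
tolerance are assumed):

    import numpy as np
    import torch

    x = torch.randn(4, 5)

    # unbiased=True divides by N - 1, like NumPy's ddof=1;
    # unbiased=False divides by N, like ddof=0.
    assert np.allclose(torch.std(x, dim=1, unbiased=True).numpy(),
                       np.std(x.numpy(), axis=1, ddof=1), rtol=1e-4)

    # correction=c divides by N - c, like ddof=c. Per the comment in the
    # diff, the correction overload currently expects both dim and keepdim.
    assert np.allclose(torch.std(x, dim=1, keepdim=False, correction=2).numpy(),
                       np.std(x.numpy(), axis=1, ddof=2), rtol=1e-4)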

torch/testing/_internal/common_methods_invocations.py

index 6331c31..9d5867c 100644
@@ -5890,6 +5890,39 @@ def reference_reduction_numpy(f, supports_keepdims=True):
     return wrapper
 
 
+def reference_std_var(f):
+    """Forwards unbiased/correction kwargs as NumPy's equivalent ddof"""
+    g = reference_reduction_numpy(f)
+
+    @wraps(g)
+    def wrapper(x: np.ndarray, *args, **kwargs):
+        assert not ('unbiased' in kwargs and 'correction' in kwargs)
+
+        if 'unbiased' in kwargs:
+            kwargs['ddof'] = int(kwargs.pop('unbiased'))
+        elif 'correction' in kwargs:
+            kwargs['ddof'] = kwargs.pop('correction')
+
+        return g(x, *args, **kwargs)
+
+    return wrapper
+
+
+def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
+    """Generates unbiased/correction kwargs for std/var operators"""
+    yield ((), {'unbiased': True})
+    yield ((), {'unbiased': False})
+
+    # Currently, calling std with correction is only enabled when
+    # both dim and keepdim are provided.
+    if 'dim' in kwargs and 'keepdim' in kwargs:
+        yield ((), {'correction': 0})
+        yield ((), {'correction': 1})
+
+        numel = torch.tensor(t.shape)[kwargs.get('dim')].prod()
+        yield ((), {'correction': numel // 2})
+
+
 # Operator database (sorted alphabetically)
 op_db: List[OpInfo] = [
     UnaryUfuncInfo('abs',
@@ -7322,14 +7355,6 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
            )),
     OpInfo('max',
-           op=torch.max,
-           variant_test_name='binary',
-           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-           sample_inputs_func=sample_inputs_max_min_binary,
-           supports_forward_ad=True,
-           assert_autodiffed=True,),
-    OpInfo('max',
-           op=torch.max,
            variant_test_name='reduction_with_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
@@ -7338,7 +7363,6 @@ op_db: List[OpInfo] = [
                # max does not correctly warn when resizing out= inputs
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),)),
     OpInfo('max',
-           op=torch.max,
            variant_test_name='reduction_no_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            supports_out=False,
@@ -7431,14 +7455,6 @@ op_db: List[OpInfo] = [
            autodiff_nonfusible_nodes=[],
            supports_forward_ad=True),
     OpInfo('min',
-           op=torch.min,
-           variant_test_name='binary',
-           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-           sample_inputs_func=sample_inputs_max_min_binary,
-           supports_forward_ad=True,
-           assert_autodiffed=True,),
-    OpInfo('min',
-           op=torch.min,
            variant_test_name='reduction_with_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
@@ -7448,7 +7464,6 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
            )),
     OpInfo('min',
-           op=torch.min,
            variant_test_name='reduction_no_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            supports_out=False,
@@ -7460,16 +7475,56 @@ op_db: List[OpInfo] = [
     OpInfo('nanquantile',
            dtypes=floating_types(),
            sample_inputs_func=sample_inputs_reduction_quantile),
-    OpInfo('maximum',
-           op=torch.maximum,
-           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-           supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_max_min_binary,),
-    OpInfo('minimum',
-           op=torch.minimum,
-           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-           supports_forward_ad=True,
-           sample_inputs_func=sample_inputs_max_min_binary,),
+    BinaryUfuncInfo(
+        'max',
+        aliases=('maximum',),
+        variant_test_name='binary',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        sample_inputs_func=sample_inputs_max_min_binary,
+        supports_forward_ad=True,
+        assert_autodiffed=True,
+        ref=np.maximum,
+        skips=(
+            # FIXME: maximum does not accept scalar inputs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'),
+        ),
+    ),
+    BinaryUfuncInfo(
+        'maximum',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        supports_forward_ad=True,
+        sample_inputs_func=sample_inputs_max_min_binary,
+        ref=np.maximum,
+        skips=(
+            # FIXME: maximum does not accept scalar inputs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'),
+        ),
+    ),
+    BinaryUfuncInfo(
+        'min',
+        aliases=('minimum',),
+        variant_test_name='binary',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        sample_inputs_func=sample_inputs_max_min_binary,
+        supports_forward_ad=True,
+        assert_autodiffed=True,
+        ref=np.minimum,
+        skips=(
+            # FIXME: min does not accept scalar inputs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'),
+        ),
+    ),
+    BinaryUfuncInfo(
+        'minimum',
+        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+        supports_forward_ad=True,
+        sample_inputs_func=sample_inputs_max_min_binary,
+        ref=np.minimum,
+        skips=(
+            # FIXME: minimum does not accept scalar inputs
+            DecorateInfo(unittest.skip("Skipped!"), 'TestBinaryUfuncs', 'test_broadcast_python_scalar'),
+        ),
+    ),
     # `softmax` supports different dtypes based on whether the `dtype` argument
     # is passed or not. Hence two OpInfo entries, one with dtype and one without.
     OpInfo('softmax',
@@ -8246,16 +8301,6 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_legacy_solve,
            check_batched_gradgrad=False,
            decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
-    OpInfo('std',
-           dtypes=floating_and_complex_types_and(torch.half),
-           dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           backward_dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_std_var,
-           # TODO: std does support out in some signatures
-           supports_out=False,
-           assert_autodiffed=True,
-           ),
     UnaryUfuncInfo('tan',
                    ref=np.tan,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
@@ -9034,17 +9079,6 @@ op_db: List[OpInfo] = [
            assert_jit_shape_analysis=True,
            assert_autodiffed=True,
            sample_inputs_func=sample_unsqueeze),
-    OpInfo('var',
-           dtypes=floating_and_complex_types_and(torch.half),
-           dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           backward_dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           sample_inputs_func=sample_inputs_std_var,
-           # TODO: revisit, some var signatures do support out (see std, too)
-           supports_out=False,
-           assert_autodiffed=True,
-           ),
     OpInfo('xlogy',
            aliases=('special.xlogy',),
            dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
@@ -9670,6 +9704,68 @@ op_db: List[OpInfo] = [
         ),
     ),
     ReductionOpInfo(
+        'std',
+        nan_policy='propagate',
+        supports_out=False,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.half),
+        dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_std_var,
+        ref=reference_std_var(np.std),
+        generate_args_kwargs=generate_std_var_kwargs,
+        skips=(
+            # FIXME: cannot specify keepdim without dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: dim=None not supported
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # TODO(@heitorschueroff) std returns float for complex types
+            # need to find a better way to model result dtype
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_result_dtype'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
+            # NumPy is giving NaN for this
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
+        ),
+    ),
+    ReductionOpInfo(
+        'var',
+        nan_policy='propagate',
+        supports_out=False,
+        assert_autodiffed=True,
+        promotes_int_to_float=True,
+        dtypes=floating_and_complex_types_and(torch.half),
+        dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_std_var,
+        ref=reference_std_var(np.var),
+        generate_args_kwargs=generate_std_var_kwargs,
+        skips=(
+            # FIXME: cannot specify keepdim without dim
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
+            # FIXME: dim=None not supported
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
+            # FIXME: dim=[] reduces all dimensions
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+            # TODO(@heitorschueroff) var returns float for complex types
+            # need to find a better way to model result dtype
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_result_dtype'),
+            # FIXME: improve precision
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
+            # NumPy is giving NaN for this
+            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
+        ),
+    ),
+    ReductionOpInfo(
         'prod',
         identity=1,
         nan_policy='propagate',