Revert D30558877: Ported std/var to ReductionOpInfo and minimum/maximum to BinaryUfuncInfo
author    Supriya Rao <supriyar@fb.com>
          Wed, 15 Sep 2021 00:32:15 +0000 (17:32 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Wed, 15 Sep 2021 00:33:38 +0000 (17:33 -0700)
Test Plan: revert-hammer

Differential Revision:
D30558877 (https://github.com/pytorch/pytorch/commit/382e008fbf5cc91c283fc902bb0dd6cb7d4bbfda)

Original commit changeset: 3e62ff24a935

fbshipit-source-id: 3b9f03c1f43c6d5f2738ed139d0236f2ded78dbf
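
Note: the reverted change tested `std`/`var` against NumPy through the `reference_std_var` wrapper removed in the diff below, which forwards torch's `unbiased`/`correction` kwargs as NumPy's `ddof`. A minimal standalone sketch of that mapping (illustrative only, assuming NumPy is available):

    import numpy as np
    import torch

    x = torch.randn(4, 5, dtype=torch.float64)

    # unbiased=True maps to ddof=1 (Bessel's correction), unbiased=False to ddof=0.
    assert abs(torch.std(x, unbiased=True).item() - np.std(x.numpy(), ddof=1)) < 1e-12

    # The newer `correction` kwarg is NumPy's ddof passed through directly; per the
    # removed generate_std_var_kwargs helper, it was only exercised together with
    # dim and keepdim at the time of this commit.
    torch_var = torch.var(x, dim=1, keepdim=False, correction=0).numpy()
    numpy_var = np.var(x.numpy(), axis=1, ddof=0)
    assert np.allclose(torch_var, numpy_var)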

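Likewise, the reverted `BinaryUfuncInfo` entries compared the binary `max`/`min` variants and `maximum`/`minimum` against `np.maximum`/`np.minimum` as references. A rough sketch of that equivalence (again illustrative, not part of the commit):

    import numpy as np
    import torch

    a = torch.randn(3, 4)
    b = torch.randn(3, 4)

    # torch.maximum / torch.minimum are elementwise and mirror np.maximum / np.minimum.
    assert np.array_equal(torch.maximum(a, b).numpy(), np.maximum(a.numpy(), b.numpy()))
    assert np.array_equal(torch.minimum(a, b).numpy(), np.minimum(a.numpy(), b.numpy()))

    # The two-tensor overloads of torch.max / torch.min behave the same as
    # maximum / minimum, which is why the 'binary' OpInfo variants below reuse
    # sample_inputs_max_min_binary.
    assert torch.equal(torch.max(a, b), torch.maximum(a, b))
    assert torch.equal(torch.min(a, b), torch.minimum(a, b))
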
torch/testing/_internal/common_methods_invocations.py

index 2297337..5dd1cb2 100644
@@ -5810,39 +5810,6 @@ def reference_reduction_numpy(f, supports_keepdims=True):
     return wrapper
 
 
-def reference_std_var(f):
-    """Forwards unbiased/correction kwargs as NumPy's equivalent ddof"""
-    g = reference_reduction_numpy(f)
-
-    @wraps(g)
-    def wrapper(x: np.ndarray, *args, **kwargs):
-        assert not ('unbiased' in kwargs and 'correction' in kwargs)
-
-        if 'unbiased' in kwargs:
-            kwargs['ddof'] = int(kwargs.pop('unbiased'))
-        elif 'correction' in kwargs:
-            kwargs['ddof'] = kwargs.pop('correction')
-
-        return g(x, *args, **kwargs)
-
-    return wrapper
-
-
-def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
-    """Generates unbiased/correction kwargs for std/var operators"""
-    yield ((), {'unbiased': True})
-    yield ((), {'unbiased': False})
-
-    # Currently, calling std with correction is only enabled when
-    # both dim and keepdim are provided.
-    if 'dim' in kwargs and 'keepdim' in kwargs:
-        yield ((), {'correction': 0})
-        yield ((), {'correction': 1})
-
-        numel = torch.tensor(t.shape)[kwargs.get('dim')].prod()
-        yield ((), {'correction': numel // 2})
-
-
 # Operator database (sorted alphabetically)
 op_db: List[OpInfo] = [
     UnaryUfuncInfo('abs',
@@ -7257,6 +7224,14 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
            )),
     OpInfo('max',
+           op=torch.max,
+           variant_test_name='binary',
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           sample_inputs_func=sample_inputs_max_min_binary,
+           supports_forward_ad=True,
+           assert_autodiffed=True,),
+    OpInfo('max',
+           op=torch.max,
            variant_test_name='reduction_with_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
@@ -7265,6 +7240,7 @@ op_db: List[OpInfo] = [
                # max does not correctly warn when resizing out= inputs
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),)),
     OpInfo('max',
+           op=torch.max,
            variant_test_name='reduction_no_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            supports_out=False,
@@ -7359,6 +7335,14 @@ op_db: List[OpInfo] = [
            autodiff_nonfusible_nodes=[],
            supports_forward_ad=True),
     OpInfo('min',
+           op=torch.min,
+           variant_test_name='binary',
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           sample_inputs_func=sample_inputs_max_min_binary,
+           supports_forward_ad=True,
+           assert_autodiffed=True,),
+    OpInfo('min',
+           op=torch.min,
            variant_test_name='reduction_with_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_max_min_reduction_with_dim,
@@ -7368,6 +7352,7 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
            )),
     OpInfo('min',
+           op=torch.min,
            variant_test_name='reduction_no_dim',
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            supports_out=False,
@@ -7379,56 +7364,16 @@ op_db: List[OpInfo] = [
     OpInfo('nanquantile',
            dtypes=floating_types(),
            sample_inputs_func=sample_inputs_reduction_quantile),
-    BinaryUfuncInfo(
-        'max',
-        aliases=('maximum',),
-        variant_test_name='binary',
-        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        sample_inputs_func=sample_inputs_max_min_binary,
-        supports_forward_ad=True,
-        assert_autodiffed=True,
-        ref=np.maximum,
-        skips=(
-            # FIXME: maximum does not accept scalar inputs
-            SkipInfo('TestBinaryUfuncs', 'test_broadcast_python_scalar'),
-        ),
-    ),
-    BinaryUfuncInfo(
-        'maximum',
-        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        supports_forward_ad=True,
-        sample_inputs_func=sample_inputs_max_min_binary,
-        ref=np.maximum,
-        skips=(
-            # FIXME: maximum does not accept scalar inputs
-            SkipInfo('TestBinaryUfuncs', 'test_broadcast_python_scalar'),
-        ),
-    ),
-    BinaryUfuncInfo(
-        'min',
-        aliases=('minimum',),
-        variant_test_name='binary',
-        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        sample_inputs_func=sample_inputs_max_min_binary,
-        supports_forward_ad=True,
-        assert_autodiffed=True,
-        ref=np.minimum,
-        skips=(
-            # FIXME: min does not accept scalar inputs
-            SkipInfo('TestBinaryUfuncs', 'test_broadcast_python_scalar'),
-        ),
-    ),
-    BinaryUfuncInfo(
-        'minimum',
-        dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
-        supports_forward_ad=True,
-        sample_inputs_func=sample_inputs_max_min_binary,
-        ref=np.minimum,
-        skips=(
-            # FIXME: minimum does not accept scalar inputs
-            SkipInfo('TestBinaryUfuncs', 'test_broadcast_python_scalar'),
-        ),
-    ),
+    OpInfo('maximum',
+           op=torch.maximum,
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_max_min_binary,),
+    OpInfo('minimum',
+           op=torch.minimum,
+           dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_max_min_binary,),
     # `softmax` supports different dtypes based on whether `dtype` argument,
     # is passed or not. Hence two OpInfo entries, one with dtype and other without.
     OpInfo('softmax',
@@ -8196,6 +8141,16 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_legacy_solve,
            check_batched_gradgrad=False,
            decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+    OpInfo('std',
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           backward_dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_std_var,
+           # TODO: std does support out in some signatures
+           supports_out=False,
+           assert_autodiffed=True,
+           ),
     UnaryUfuncInfo('tan',
                    ref=np.tan,
                    dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
@@ -8974,6 +8929,17 @@ op_db: List[OpInfo] = [
            assert_jit_shape_analysis=True,
            assert_autodiffed=True,
            sample_inputs_func=sample_unsqueeze),
+    OpInfo('var',
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           backward_dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+           sample_inputs_func=sample_inputs_std_var,
+           # TODO: revisit, some var signatures do support out (see std, too)
+           supports_out=False,
+           assert_autodiffed=True,
+           ),
     OpInfo('xlogy',
            aliases=('special.xlogy',),
            dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
@@ -9575,68 +9541,6 @@ op_db: List[OpInfo] = [
         ),
     ),
     ReductionOpInfo(
-        'std',
-        nan_policy='propagate',
-        supports_out=False,
-        assert_autodiffed=True,
-        promotes_int_to_float=True,
-        dtypes=floating_and_complex_types_and(torch.half),
-        dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        sample_inputs_func=sample_inputs_std_var,
-        ref=reference_std_var(np.std),
-        generate_args_kwargs=generate_std_var_kwargs,
-        skips=(
-            # FIXME: cannot specify keepdim without dim
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
-            # FIXME: dim=None not supported
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
-            # FIXME: dim=[] reduces all dimensions
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
-            # TODO(@heitorschueroff) std return float for complex types
-            # need to find a better way to model result dtype
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_result_dtype'),
-            # FIXME: improve precision
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
-            # NumPy is giving NaN for this
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
-        ),
-    ),
-    ReductionOpInfo(
-        'var',
-        nan_policy='propagate',
-        supports_out=False,
-        assert_autodiffed=True,
-        promotes_int_to_float=True,
-        dtypes=floating_and_complex_types_and(torch.half),
-        dtypesIfCPU=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-        sample_inputs_func=sample_inputs_std_var,
-        ref=reference_std_var(np.var),
-        generate_args_kwargs=generate_std_var_kwargs,
-        skips=(
-            # FIXME: cannot specify keepdim without dim
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
-            # FIXME: dim=None not supported
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
-            # FIXME: dim=[] reduces all dimensions
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
-            # TODO(@heitorschueroff) std return float for complex types
-            # need to find a better way to model result dtype
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_result_dtype'),
-            # FIXME: improve precision
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values'),
-            # NumPy is giving NaN for this
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_large_input'),
-        ),
-    ),
-    ReductionOpInfo(
         'prod',
         identity=1,
         nan_policy='propagate',