Remove `run_functional_checks` from `test_autograd` and create necessary OpInfos...
author soulitzer <soulitzer@gmail.com>
Wed, 15 Sep 2021 19:43:54 +0000 (12:43 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 19:45:38 +0000 (12:45 -0700)
Summary:
OpInfo tracker: https://github.com/pytorch/pytorch/issues/54261

 - Eliminated duplicated testing logic in test_autograd
 - Moved tests that relied on this testing logic to use OpInfos
   - `cat` already has an OpInfo (no action needed)
   - Created OpInfos for `block_diag` and `broadcast_tensors`

Ran into some FX errors; added both ops to the FX skip list and created an issue here: https://github.com/pytorch/pytorch/issues/64997
Both `block_diag` and `broadcast_tensors` are variadic, so `test_variant_consistency_jit` is skipped for them (judging from comments on other OpInfos, the JIT does not support variadic tensor arguments).
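
As a rough illustration (not part of this change), the checks the removed helpers ran by hand are essentially what the OpInfo-driven gradcheck/gradgradcheck tests now perform automatically per dtype and device. A minimal sketch for the variadic `broadcast_tensors` case, with illustrative shapes:

    import torch
    from torch.autograd import gradcheck, gradgradcheck

    # Inputs mirror the removed test_broadcast_tensors case; shapes are illustrative.
    inputs = (torch.randn(3, dtype=torch.double, requires_grad=True),
              torch.randn(1, 2, 1, dtype=torch.double, requires_grad=True),
              torch.randn(1, 1, dtype=torch.double, requires_grad=True),
              torch.randn(5, 1, 1, dtype=torch.double, requires_grad=True))

    # gradcheck/gradgradcheck take variadic ops directly; the OpInfo's
    # sample_inputs_func supplies equivalent inputs to the generic test suite.
    assert gradcheck(torch.broadcast_tensors, inputs, eps=1e-6, atol=1e-4)
    assert gradgradcheck(torch.broadcast_tensors, inputs, gen_non_contig_grad_outputs=True)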

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64993

Reviewed By: jbschlosser

Differential Revision: D30961736

Pulled By: soulitzer

fbshipit-source-id: e169305384a683acae1178c4e12e9e214a67226a

test/test_autograd.py
test/test_fx.py
test/test_fx_experimental.py
torch/testing/_internal/common_methods_invocations.py

test/test_autograd.py
index 73f9c0e..8c563da 100644 (file)
@@ -34,10 +34,7 @@ from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoL
 from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 import torch.autograd.forward_ad as fwAD
-from torch.testing._internal.common_methods_invocations import (
-    unpack_variables,
-    mask_not_all_zeros,
-    S)
+from torch.testing._internal.common_methods_invocations import mask_not_all_zeros
 from torch.testing._internal.common_device_type import (instantiate_device_type_tests, skipCUDAIfRocm,
                                                         onlyCPU, onlyCUDA, onlyOnCPUAndCUDA, dtypes, dtypesIfCUDA,
                                                         deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan,
@@ -46,8 +43,6 @@ from torch.testing._internal.common_dtype import get_all_dtypes
 
 import pickle
 
-PRECISION = 1e-4
-
 
 def graph_desc(fn):
     if fn is None:
@@ -2716,37 +2711,6 @@ class TestAutograd(TestCase):
         with self.assertRaisesRegex(Exception, 'Simulate error'):
             d.sum().backward()
 
-    # TODO: Create OpInfos for these ops
-    def test_broadcast_tensors(self):
-        f_args_variable = (torch.randn(3, dtype=torch.double, requires_grad=True),
-                           torch.randn(1, 2, 1, dtype=torch.double, requires_grad=True),
-                           torch.randn(1, 1, dtype=torch.double, requires_grad=True),
-                           torch.randn(5, 1, 1, dtype=torch.double, requires_grad=True))
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_broadcast_tensors", "broadcast",
-                              lambda a, b, c, d: torch.broadcast_tensors(a, b, c, d),
-                              True, f_args_variable, f_args_tensor)
-
-    def test_block_diag(self):
-        f_args_variable = (torch.randn(1, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(2, S, dtype=torch.double, requires_grad=True),
-                           torch.randn(3, S, dtype=torch.double, requires_grad=True))
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_block_diag", "block_diag",
-                              lambda a, b, c: torch.block_diag(a, b, c),
-                              True, f_args_variable, f_args_tensor)
-
-    def test_cat_empty_legacy(self):
-        f_args_variable = (torch.randn(0, dtype=torch.double, requires_grad=True),
-                           torch.randn(S, S, dtype=torch.double, requires_grad=True))
-        # gradgradcheck doesn't work, probably because legacy size tracking is wrong somewhere,
-        # hence False passed below, but gradcheck checked explicitly.
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_cat_empty_legacy", "cat",
-                              lambda a, b: torch.cat((a, b)),
-                              False, f_args_variable, f_args_tensor, check_forward_ad=True)
-        self.assertTrue(gradcheck(lambda a, b: torch.cat((a, b)), f_args_variable, eps=1e-6, atol=PRECISION))
-
     def test_var_mean_differentiable(self):
         dim = [2, 4]
         keepdim = False
@@ -6061,59 +6025,6 @@ def bernoulli_scalar():
     return torch.tensor(0, dtype=torch.uint8).bernoulli_()
 
 
-def gradgradcheck_method_precision_override(test_name):
-    # these are just empirical observations, we should improve
-    gradgradcheck_precision_override = {
-        'test_norm': {'atol': 2e-2, 'rtol': 1e-2},
-        'test_norm_1_5': {'atol': 1.5e-2, 'rtol': 1e-2},
-        'test_norm_3': {'atol': 5e-2, 'rtol': 1e-2},
-        'test_dist': {'atol': 5e-2, 'rtol': 1e-2},
-        'test_dist_4': {'atol': 8e-2, 'rtol': 1e-2},
-    }
-    non_broadcasted_test_name = test_name.split("_broadcast")[0]
-    override = gradgradcheck_precision_override.get(non_broadcasted_test_name)
-    if override:
-        if 'broadcast_lhs' in test_name or 'broadcast_rhs' in test_name:
-            # errors accumulated across 1 dimension
-            override = {'atol': override['atol'] * S, 'rtol': override['atol'] * S}
-        elif 'broadcast_all' in test_name:
-            # errors accumulated across multiple dimensions
-            override = {'atol': override['atol'] * S * S, 'rtol': override['atol'] * S * S}
-    return override
-
-def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
-                                 input_variables, run_gradgradcheck=True, check_batched_grad=True,
-                                 check_forward_ad=False):
-    test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION,
-                                   check_batched_grad=check_batched_grad, check_forward_ad=check_forward_ad))
-    gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
-    if gradgradcheck_precision_override is not None:
-        atol = gradgradcheck_precision_override['atol']
-        rtol = gradgradcheck_precision_override['rtol']
-        test_case.assertTrue(gradgradcheck(apply_method, input_variables, None, atol=atol, rtol=rtol,
-                                           gen_non_contig_grad_outputs=True,
-                                           check_batched_grad=check_batched_grad))
-    else:
-        test_case.assertTrue(gradgradcheck(apply_method, input_variables,
-                                           gen_non_contig_grad_outputs=True,
-                                           check_batched_grad=check_batched_grad))
-
-
-def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
-                          f_args_variable, f_args_tensor, *, check_forward_ad=False):
-    output_variable = apply_fn(*f_args_variable)
-
-    if run_grad_checks:
-        run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
-                                     output_variable, f_args_variable, check_forward_ad=check_forward_ad)
-
-    self_variable = f_args_variable[0]
-    if isinstance(output_variable, torch.Tensor) and output_variable.requires_grad and self_variable is not None:
-        output_variable.backward(torch.randn_like(output_variable))
-        test_case.assertEqualTypeString(self_variable, self_variable.grad)
-        test_case.assertEqual(self_variable.size(), self_variable.grad.size())
-
-
 class TestAutogradFunctional(TestCase):
     def _assert_same_struct(self, res, base):
         # base and res should be Tensors or tuple of Tensors with the same size
test/test_fx.py
index c509c3b..722f5ce 100644 (file)
@@ -3072,7 +3072,9 @@ class TestOperatorSignatures(JitTestCase):
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
         # Sorted and one entry on each line to minimize merge conflicts.
-        known_no_schema = {'cdist',
+        known_no_schema = {'block_diag',
+                           'broadcast_tensors',
+                           'cdist',
                            'contiguous',
                            'dstack',
                            'einsum',
test/test_fx_experimental.py
index 0727912..f2ae6c6 100644 (file)
@@ -1459,6 +1459,9 @@ class TestNormalizeOperators(JitTestCase):
     def test_normalize_operator_exhaustive(self, device, dtype, op):
         # Sorted and one entry on each line to minimize merge conflicts.
         op_skip = {
+            # See: https://github.com/pytorch/pytorch/issues/64997
+            "block_diag",
+            "broadcast_tensors",
             "contiguous",
             "einsum",
             "expand",
torch/testing/_internal/common_methods_invocations.py
index 7289750..26f64b3 100644 (file)
@@ -2043,6 +2043,26 @@ def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad, **kwargs):
             make_tensor(size, device, dtype, low=None, high=None, requires_grad=requires_grad),
             args=(shape,)) for size, shape in test_cases)
 
+def sample_inputs_broadcast_tensors(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    test_cases: Tuple[tuple] = (((3,), (1, 2, 1), (1, 1), (5, 1, 1),),)
+
+    samples: List[SampleInput] = []
+    for shape, *other_shapes in test_cases:
+        samples.append(SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)))
+
+    return samples
+
+def sample_inputs_block_diag(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    test_cases: Tuple[tuple] = (((1, S), (2, S), (3, S),),)
+
+    samples: List[SampleInput] = []
+    for shape, *other_shapes in test_cases:
+        samples.append(SampleInput(make_arg(shape), args=tuple(make_arg(s) for s in other_shapes)))
+
+    return samples
+
 def sample_inputs_bitwise_shift(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (
         (S, S, S),
@@ -6221,6 +6241,24 @@ op_db: List[OpInfo] = [
            supports_out=False,
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_broadcast_to),
+    OpInfo('broadcast_tensors',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           skips=(
+               # JIT does not support variadic tensors.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_broadcast_tensors),
+    OpInfo('block_diag',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           skips=(
+               # JIT does not support variadic tensors.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_block_diag),
     OpInfo('bitwise_and',
            dtypes=integral_types_and(torch.bool),
            supports_autograd=False,