Revert D30867266: [pytorch][PR] TST Adds gradcheck and gradgradcheck to module info
authorNikita Shulga <nshulga@fb.com>
Sun, 12 Sep 2021 17:29:10 +0000 (10:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sun, 12 Sep 2021 17:30:28 +0000 (10:30 -0700)
Test Plan: revert-hammer

Differential Revision:
D30867266 (https://github.com/pytorch/pytorch/commit/67ebde56459557199b3c907b81b3c819f77500b9)

Original commit changeset: cbc073326151

fbshipit-source-id: 00234e01eafc45fb999f7c83a397f9d6b3e01e46

test/test_modules.py
torch/testing/_internal/common_modules.py

index 3effcf2..37ab347 100644 (file)
@@ -6,7 +6,7 @@ import torch
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
+    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
 from unittest.mock import patch
 
 
@@ -206,58 +206,6 @@ class TestModule(TestCase):
             self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 
 
-    def _test_gradients_helper(self, device, dtype, module_info, check):
-        # Check gradients
-        module_cls = module_info.module_cls
-        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
-                                                       requires_grad=True)
-
-        for module_input in module_inputs:
-            if module_input.forward_input is None:
-                continue
-
-            # === Instantiate the module. ===
-            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
-            m = module_cls(*args, **kwargs)
-            m.to(device).to(dtype)
-
-            params = tuple(m.parameters())
-
-            # === Perform gradient check on the input_args ===
-            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
-
-            other_kwargs = {}
-            kwarg_tensors = []
-            for name, obj in input_kwargs.items():
-                if isinstance(obj, torch.Tensor):
-                    kwarg_tensors.append((name, obj))
-                else:
-                    other_kwargs[name] = obj
-
-            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
-
-            def fn_to_gradcheck(*input_and_params):
-                new_input_args = input_and_params[:len(input_args)]
-                kwarg_args = input_and_params[-len(kwarg_tensors):]
-                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
-
-                with freeze_rng_state():
-                    return m(*new_input_args, **new_kwargs, **other_kwargs)
-
-            self.assertTrue(check(fn_to_gradcheck, grad_input))
-
-
-    @modules(module_db, allowed_dtypes=[torch.double])
-    def test_grad(self, device, dtype, module_info):
-        self._test_gradients_helper(device, dtype, module_info, gradcheck)
-
-    @modules(module_db, allowed_dtypes=[torch.double])
-    def test_gradgrad(self, device, dtype, module_info):
-        if not module_info.supports_gradgrad:
-            self.skipTest("Skipped! Module does not support gradgrad")
-        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
-
-
 instantiate_device_type_tests(TestModule, globals())
 
 if __name__ == '__main__':
index 0a36309..b1cbbb3 100644 (file)
@@ -2,7 +2,6 @@ import torch
 from copy import deepcopy
 from functools import wraps, partial
 from itertools import chain
-import torch.nn.functional as F
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import floating_types
 from torch.testing._internal.common_device_type import (
@@ -49,18 +48,13 @@ for namespace in MODULE_NAMESPACES:
 class modules(_TestParametrizer):
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
 
-    def __init__(self, module_info_list, allowed_dtypes=None):
+    def __init__(self, module_info_list):
         self.module_info_list = module_info_list
-        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
 
     def _parametrize_test(self, test, generic_cls, device_cls):
         for module_info in self.module_info_list:
             # TODO: Factor some of this out since it's similar to OpInfo.
-            dtypes = set(module_info.dtypes)
-            if self.allowed_dtypes is not None:
-                dtypes = dtypes.intersection(self.allowed_dtypes)
-
-            for dtype in dtypes:
+            for dtype in floating_types():
                 # Construct the test name.
                 test_name = '{}_{}_{}{}'.format(test.__name__,
                                                 module_info.name.replace('.', '_'),
@@ -146,15 +140,11 @@ class ModuleInfo(object):
                  module_inputs_func,  # Function to generate module inputs
                  skips=(),  # Indicates which tests to skip
                  decorators=None,  # Additional decorators to apply to generated tests
-                 dtypes=floating_types(),  # dtypes this function is expected to work with
-                 supports_gradgrad=True,  # whether the op supports second order gradients
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
         self.skips = skips
         self.decorators = decorators
-        self.dtypes = dtypes
-        self.supports_gradgrad = supports_gradgrad
 
     def should_skip(self, cls_name, test_name, device_type, dtype):
         return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
@@ -169,8 +159,8 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
 
     module_inputs = [
         ModuleInput(constructor_input=FunctionInput(10, 8),
-                    forward_input=FunctionInput(input=make_input((4, 10))),
-                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+                    forward_input=FunctionInput(make_input((4, 10))),
+                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
         ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
                     forward_input=FunctionInput(make_input((4, 10))),
                     desc='no_bias',
@@ -186,20 +176,14 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
 
 def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
     cases: List[Tuple[str, dict]] = [
         ('', {}),
         ('ignore_index', {'ignore_index': 2}),
-        ('weights', {'weight': make_weight(10).abs()}),
-        ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
-        ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
+        ('weights', {'weight': make_input(10)}),
+        ('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1})
     ]
-
-    # TODO: Uncomment when negative weights is supported.
-    # negative_weight = make_weight(10)
-    # negative_weight[0] = -1
-    # cases.append(('weights_negative', {'weight': negative_weight}))
     module_inputs = []
     for desc, constructor_kwargs in cases:
 
@@ -318,55 +302,6 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)
 
 
-def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
-    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    return [
-        ModuleInput(
-            constructor_input=FunctionInput(),
-            forward_input=FunctionInput(make_input(shape=4)),
-            reference_fn=no_batch_dim_reference_fn,
-            desc='no_batch_dim',
-        )
-    ]
-
-
-def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
-    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
-    return [
-        ModuleInput(
-            constructor_input=FunctionInput(4, 2, 16, 0.0),
-            forward_input=FunctionInput(
-                make_input(shape=(2, 3, 4))
-            ),
-            desc='relu_activation'
-        ),
-        ModuleInput(
-            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
-            forward_input=FunctionInput(
-                make_input(shape=(2, 3, 4))
-            ),
-            desc='gelu_activation'
-        ),
-    ]
-
-
-def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
-    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
-    return [
-        ModuleInput(
-            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
-            forward_input=FunctionInput(make_empty(2, 3).random_(4))
-        ),
-        ModuleInput(
-            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
-            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
-            desc='discontiguous'
-        ),
-    ]
-
-
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.AvgPool1d,
@@ -379,14 +314,6 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=module_inputs_torch_nn_Linear),
     ModuleInfo(torch.nn.NLLLoss,
                module_inputs_func=module_inputs_torch_nn_NLLLoss),
-    ModuleInfo(torch.nn.Hardswish,
-               module_inputs_func=module_inputs_torch_nn_Hardswish,
-               supports_gradgrad=False),
-    ModuleInfo(torch.nn.TransformerEncoderLayer,
-               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
-               supports_gradgrad=False),
-    ModuleInfo(torch.nn.Embedding,
-               module_inputs_func=module_inputs_torch_nn_Embedding),
     ModuleInfo(torch.nn.ReLU,
                module_inputs_func=module_inputs_torch_nn_ReLU),
 ]