TST Adds gradcheck and gradgradcheck to module info (#64444)
authorThomas J. Fan <thomasjpfan@gmail.com>
Fri, 10 Sep 2021 23:25:21 +0000 (16:25 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 23:53:11 +0000 (16:53 -0700)
Summary:
Follow up to https://github.com/pytorch/pytorch/issues/61935

cc albanD mruberry jbschlosser walterddr

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64444

Reviewed By: ngimel

Differential Revision: D30867266

Pulled By: jbschlosser

fbshipit-source-id: cbc0733261517dbfcdd3415d969b9e802b62b7ac

test/test_modules.py
torch/testing/_internal/common_modules.py

index 37ab347..3effcf2 100644 (file)
@@ -6,7 +6,7 @@ import torch
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
+    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
 from unittest.mock import patch
 
 
@@ -206,6 +206,58 @@ class TestModule(TestCase):
             self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 
 
+    def _test_gradients_helper(self, device, dtype, module_info, check):
+        # Check gradients
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True)
+
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+
+            params = tuple(m.parameters())
+
+            # === Perform gradient check on the input_args ===
+            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            other_kwargs = {}
+            kwarg_tensors = []
+            for name, obj in input_kwargs.items():
+                if isinstance(obj, torch.Tensor):
+                    kwarg_tensors.append((name, obj))
+                else:
+                    other_kwargs[name] = obj
+
+            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
+
+            def fn_to_gradcheck(*input_and_params):
+                new_input_args = input_and_params[:len(input_args)]
+                kwarg_args = input_and_params[-len(kwarg_tensors):]
+                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
+
+                with freeze_rng_state():
+                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+
+            self.assertTrue(check(fn_to_gradcheck, grad_input))
+
+
+    @modules(module_db, allowed_dtypes=[torch.double])
+    def test_grad(self, device, dtype, module_info):
+        self._test_gradients_helper(device, dtype, module_info, gradcheck)
+
+    @modules(module_db, allowed_dtypes=[torch.double])
+    def test_gradgrad(self, device, dtype, module_info):
+        if not module_info.supports_gradgrad:
+            self.skipTest("Skipped! Module does not support gradgrad")
+        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
+
+
 instantiate_device_type_tests(TestModule, globals())
 
 if __name__ == '__main__':
index b1cbbb3..0a36309 100644 (file)
@@ -2,6 +2,7 @@ import torch
 from copy import deepcopy
 from functools import wraps, partial
 from itertools import chain
+import torch.nn.functional as F
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import floating_types
 from torch.testing._internal.common_device_type import (
@@ -48,13 +49,18 @@ for namespace in MODULE_NAMESPACES:
 class modules(_TestParametrizer):
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
 
-    def __init__(self, module_info_list):
+    def __init__(self, module_info_list, allowed_dtypes=None):
         self.module_info_list = module_info_list
+        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
 
     def _parametrize_test(self, test, generic_cls, device_cls):
         for module_info in self.module_info_list:
             # TODO: Factor some of this out since it's similar to OpInfo.
-            for dtype in floating_types():
+            dtypes = set(module_info.dtypes)
+            if self.allowed_dtypes is not None:
+                dtypes = dtypes.intersection(self.allowed_dtypes)
+
+            for dtype in dtypes:
                 # Construct the test name.
                 test_name = '{}_{}_{}{}'.format(test.__name__,
                                                 module_info.name.replace('.', '_'),
@@ -140,11 +146,15 @@ class ModuleInfo(object):
                  module_inputs_func,  # Function to generate module inputs
                  skips=(),  # Indicates which tests to skip
                  decorators=None,  # Additional decorators to apply to generated tests
+                 dtypes=floating_types(),  # dtypes this function is expected to work with
+                 supports_gradgrad=True,  # whether the op supports second order gradients
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
         self.skips = skips
         self.decorators = decorators
+        self.dtypes = dtypes
+        self.supports_gradgrad = supports_gradgrad
 
     def should_skip(self, cls_name, test_name, device_type, dtype):
         return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
@@ -159,8 +169,8 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
 
     module_inputs = [
         ModuleInput(constructor_input=FunctionInput(10, 8),
-                    forward_input=FunctionInput(make_input((4, 10))),
-                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+                    forward_input=FunctionInput(input=make_input((4, 10))),
+                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
         ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
                     forward_input=FunctionInput(make_input((4, 10))),
                     desc='no_bias',
@@ -176,14 +186,20 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
 
 def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
     cases: List[Tuple[str, dict]] = [
         ('', {}),
         ('ignore_index', {'ignore_index': 2}),
-        ('weights', {'weight': make_input(10)}),
-        ('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}),
-        ('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1})
+        ('weights', {'weight': make_weight(10).abs()}),
+        ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
     ]
+
+    # TODO: Uncomment when negative weights is supported.
+    # negative_weight = make_weight(10)
+    # negative_weight[0] = -1
+    # cases.append(('weights_negative', {'weight': negative_weight}))
     module_inputs = []
     for desc, constructor_kwargs in cases:
 
@@ -302,6 +318,55 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)
 
 
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(shape=4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input(shape=(2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input(shape=(2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+    ]
+
+
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
+    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(2, 3).random_(4))
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
+            desc='discontiguous'
+        ),
+    ]
+
+
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.AvgPool1d,
@@ -314,6 +379,14 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=module_inputs_torch_nn_Linear),
     ModuleInfo(torch.nn.NLLLoss,
                module_inputs_func=module_inputs_torch_nn_NLLLoss),
+    ModuleInfo(torch.nn.Hardswish,
+               module_inputs_func=module_inputs_torch_nn_Hardswish,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.TransformerEncoderLayer,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.Embedding,
+               module_inputs_func=module_inputs_torch_nn_Embedding),
     ModuleInfo(torch.nn.ReLU,
                module_inputs_func=module_inputs_torch_nn_ReLU),
 ]