From: Thomas J. Fan
Date: Fri, 10 Sep 2021 23:25:21 +0000 (-0700)
Subject: TST Adds gradcheck and gradgradcheck to module info (#64444)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~291
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=67ebde56459557199b3c907b81b3c819f77500b9;p=platform%2Fupstream%2Fpytorch.git

TST Adds gradcheck and gradgradcheck to module info (#64444)

Summary:
Follow up to https://github.com/pytorch/pytorch/issues/61935

cc albanD mruberry jbschlosser walterddr

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64444

Reviewed By: ngimel

Differential Revision: D30867266

Pulled By: jbschlosser

fbshipit-source-id: cbc0733261517dbfcdd3415d969b9e802b62b7ac
---

diff --git a/test/test_modules.py b/test/test_modules.py
index 37ab347..3effcf2 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -6,7 +6,7 @@ import torch
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
-    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
+    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
 from unittest.mock import patch
 
 
@@ -206,6 +206,58 @@ class TestModule(TestCase):
 
             self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
 
+    def _test_gradients_helper(self, device, dtype, module_info, check):
+        # Check gradients
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True)
+
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+
+            params = tuple(m.parameters())
+
+            # === Perform gradient check on the input_args ===
+            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            other_kwargs = {}
+            kwarg_tensors = []
+            for name, obj in input_kwargs.items():
+                if isinstance(obj, torch.Tensor):
+                    kwarg_tensors.append((name, obj))
+                else:
+                    other_kwargs[name] = obj
+
+            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
+
+            def fn_to_gradcheck(*input_and_params):
+                new_input_args = input_and_params[:len(input_args)]
+                kwarg_args = input_and_params[-len(kwarg_tensors):]
+                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
+
+                with freeze_rng_state():
+                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+
+            self.assertTrue(check(fn_to_gradcheck, grad_input))
+
+
+    @modules(module_db, allowed_dtypes=[torch.double])
+    def test_grad(self, device, dtype, module_info):
+        self._test_gradients_helper(device, dtype, module_info, gradcheck)
+
+    @modules(module_db, allowed_dtypes=[torch.double])
+    def test_gradgrad(self, device, dtype, module_info):
+        if not module_info.supports_gradgrad:
+            self.skipTest("Skipped! Module does not support gradgrad")
+        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
+
+
 instantiate_device_type_tests(TestModule, globals())
 
 if __name__ == '__main__':
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index b1cbbb3..0a36309 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -2,6 +2,7 @@ import torch
 from copy import deepcopy
 from functools import wraps, partial
 from itertools import chain
+import torch.nn.functional as F
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import floating_types
 from torch.testing._internal.common_device_type import (
@@ -48,13 +49,18 @@ for namespace in MODULE_NAMESPACES:
 class modules(_TestParametrizer):
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
 
-    def __init__(self, module_info_list):
+    def __init__(self, module_info_list, allowed_dtypes=None):
         self.module_info_list = module_info_list
+        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
 
     def _parametrize_test(self, test, generic_cls, device_cls):
         for module_info in self.module_info_list:
             # TODO: Factor some of this out since it's similar to OpInfo.
-            for dtype in floating_types():
+            dtypes = set(module_info.dtypes)
+            if self.allowed_dtypes is not None:
+                dtypes = dtypes.intersection(self.allowed_dtypes)
+
+            for dtype in dtypes:
                 # Construct the test name.
                 test_name = '{}_{}_{}{}'.format(test.__name__,
                                                 module_info.name.replace('.', '_'),
@@ -140,11 +146,15 @@ class ModuleInfo(object):
                  module_inputs_func,  # Function to generate module inputs
                  skips=(),  # Indicates which tests to skip
                  decorators=None,  # Additional decorators to apply to generated tests
+                 dtypes=floating_types(),  # dtypes this function is expected to work with
+                 supports_gradgrad=True,  # whether the op supports second order gradients
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
         self.skips = skips
         self.decorators = decorators
+        self.dtypes = dtypes
+        self.supports_gradgrad = supports_gradgrad
 
     def should_skip(self, cls_name, test_name, device_type, dtype):
         return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
@@ -159,8 +169,8 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
 
     module_inputs = [
         ModuleInput(constructor_input=FunctionInput(10, 8),
-                    forward_input=FunctionInput(make_input((4, 10))),
-                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+                    forward_input=FunctionInput(input=make_input((4, 10))),
+                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
         ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
                     forward_input=FunctionInput(make_input((4, 10))),
                     desc='no_bias',
@@ -176,14 +186,20 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k
 
 def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
     cases: List[Tuple[str, dict]] = [
         ('', {}),
         ('ignore_index', {'ignore_index': 2}),
-        ('weights', {'weight': make_input(10)}),
-        ('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}),
-        ('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1})
+        ('weights', {'weight': make_weight(10).abs()}),
+        ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
     ]
+
+    # TODO: Uncomment when negative weights is supported.
+    # negative_weight = make_weight(10)
+    # negative_weight[0] = -1
+    # cases.append(('weights_negative', {'weight': negative_weight}))
 
     module_inputs = []
     for desc, constructor_kwargs in cases:
@@ -302,6 +318,55 @@ def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **k
                     desc='scalar')] + generate_regression_criterion_inputs(make_input)
 
 
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(shape=4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input(shape=(2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input(shape=(2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+    ]
+
+
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
+    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(2, 3).random_(4))
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
+            desc='discontiguous'
+        ),
+    ]
+
+
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.AvgPool1d,
@@ -314,6 +379,14 @@ module_db: List[ModuleInfo] = [
                module_inputs_func=module_inputs_torch_nn_Linear),
     ModuleInfo(torch.nn.NLLLoss,
                module_inputs_func=module_inputs_torch_nn_NLLLoss),
+    ModuleInfo(torch.nn.Hardswish,
+               module_inputs_func=module_inputs_torch_nn_Hardswish,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.TransformerEncoderLayer,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.Embedding,
+               module_inputs_func=module_inputs_torch_nn_Embedding),
     ModuleInfo(torch.nn.ReLU,
                module_inputs_func=module_inputs_torch_nn_ReLU),
 ]
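
Usage sketch (illustrative only, not part of the patch): the new test_grad and test_gradgrad tests reduce to running gradcheck/gradgradcheck over each module's forward pass with double-precision inputs and parameters, which is what _test_gradients_helper does for every ModuleInfo entry. A minimal standalone version of the same idea follows; it uses the public torch.autograd entry points rather than the common_utils wrappers imported above, and the choice of nn.Linear and the tensor shapes are assumptions made purely for illustration.

    import torch
    from torch.autograd import gradcheck, gradgradcheck

    # Double precision keeps the numerical Jacobian estimates accurate enough
    # for gradcheck's default tolerances (the tests above use torch.double too).
    m = torch.nn.Linear(10, 8).to(torch.double)
    x = torch.randn(4, 10, dtype=torch.double, requires_grad=True)
    params = tuple(m.parameters())

    def fn(x, *params):
        # The parameters are passed as explicit inputs so gradcheck perturbs
        # them as well, mirroring how _test_gradients_helper builds grad_input.
        return m(x)

    assert gradcheck(fn, (x,) + params)       # first-order gradients
    assert gradgradcheck(fn, (x,) + params)   # second-order gradients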