from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
-    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
+    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
from unittest.mock import patch
        self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
+    def _test_gradients_helper(self, device, dtype, module_info, check):
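+        """Check gradients of each ModuleInput from `module_info` with `check`
+        (gradcheck or gradgradcheck), differentiating with respect to the
+        forward inputs, the module's parameters, and any tensor kwargs."""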
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True)
+
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m = module_cls(*args, **kwargs)
+            m.to(device).to(dtype)
+
+            params = tuple(m.parameters())
+
+            # === Perform gradient check on the input_args ===
+            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
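+            # Split the forward kwargs into tensor kwargs (included in the
+            # gradient check) and non-tensor kwargs (passed through unchanged).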
+            other_kwargs = {}
+            kwarg_tensors = []
+            for name, obj in input_kwargs.items():
+                if isinstance(obj, torch.Tensor):
+                    kwarg_tensors.append((name, obj))
+                else:
+                    other_kwargs[name] = obj
+
+            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
+
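+            # gradcheck passes the flat tuple of tensors back in; rebuild the
+            # positional args and tensor kwargs from it before calling the
+            # module. If kwarg_tensors is empty, the negative slice below grabs
+            # everything, but zip() then yields no pairs, so new_kwargs stays empty.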
+            def fn_to_gradcheck(*input_and_params):
+                new_input_args = input_and_params[:len(input_args)]
+                kwarg_args = input_and_params[-len(kwarg_tensors):]
+                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
+
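+                # Freeze the RNG so stochastic modules produce the same output
+                # on every forward call gradcheck makes.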
+                with freeze_rng_state():
+                    return m(*new_input_args, **new_kwargs, **other_kwargs)
+
+            self.assertTrue(check(fn_to_gradcheck, grad_input))
+
+
+    @modules(module_db, allowed_dtypes=[torch.double])
+    def test_grad(self, device, dtype, module_info):
+        self._test_gradients_helper(device, dtype, module_info, gradcheck)
+
+    @modules(module_db, allowed_dtypes=[torch.double])
+    def test_gradgrad(self, device, dtype, module_info):
+        if not module_info.supports_gradgrad:
+            self.skipTest("Skipped! Module does not support gradgrad")
+        self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
+
+
instantiate_device_type_tests(TestModule, globals())
if __name__ == '__main__':
from copy import deepcopy
from functools import wraps, partial
from itertools import chain
+import torch.nn.functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_device_type import (
class modules(_TestParametrizer):
""" PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
-    def __init__(self, module_info_list):
+    def __init__(self, module_info_list, allowed_dtypes=None):
        self.module_info_list = module_info_list
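+        # Optional allowlist; when given, tests are only instantiated for these dtypes.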
+        self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
    def _parametrize_test(self, test, generic_cls, device_cls):
        for module_info in self.module_info_list:
            # TODO: Factor some of this out since it's similar to OpInfo.
-            for dtype in floating_types():
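+            # Start from the dtypes the ModuleInfo declares, then narrow to the
+            # test's allowlist, if one was provided.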
+            dtypes = set(module_info.dtypes)
+            if self.allowed_dtypes is not None:
+                dtypes = dtypes.intersection(self.allowed_dtypes)
+
+            for dtype in dtypes:
                # Construct the test name.
                test_name = '{}_{}_{}{}'.format(test.__name__,
                                                module_info.name.replace('.', '_'),
                 module_inputs_func,  # Function to generate module inputs
                 skips=(),  # Indicates which tests to skip
                 decorators=None,  # Additional decorators to apply to generated tests
+                 dtypes=floating_types(),  # dtypes this module is expected to work with
+                 supports_gradgrad=True,  # whether the module supports second-order gradients
                 ):
        self.module_cls = module_cls
        self.module_inputs_func = module_inputs_func
        self.skips = skips
        self.decorators = decorators
+        self.dtypes = dtypes
+        self.supports_gradgrad = supports_gradgrad
    def should_skip(self, cls_name, test_name, device_type, dtype):
        return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
    module_inputs = [
        ModuleInput(constructor_input=FunctionInput(10, 8),
-                    forward_input=FunctionInput(make_input((4, 10))),
-                    reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+                    forward_input=FunctionInput(input=make_input((4, 10))),
+                    reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
        ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
                    forward_input=FunctionInput(make_input((4, 10))),
                    desc='no_bias',
def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
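+    # NLLLoss's class weights are a non-differentiable argument, so they are
+    # built without requires_grad; .abs() below keeps them non-negative.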
    cases: List[Tuple[str, dict]] = [
        ('', {}),
        ('ignore_index', {'ignore_index': 2}),
-        ('weights', {'weight': make_input(10)}),
-        ('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}),
-        ('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1})
+        ('weights', {'weight': make_weight(10).abs()}),
+        ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
+        ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
    ]
+
+    # TODO: Uncomment when negative weights are supported.
+    # negative_weight = make_weight(10)
+    # negative_weight[0] = -1
+    # cases.append(('weights_negative', {'weight': negative_weight}))
    module_inputs = []
    for desc, constructor_kwargs in cases:
                    desc='scalar')] + generate_regression_criterion_inputs(make_input)
+def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(),
+            forward_input=FunctionInput(make_input(shape=4)),
+            reference_fn=no_batch_dim_reference_fn,
+            desc='no_batch_dim',
+        )
+    ]
+
+
+def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(
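+            # Positional args are (d_model, nhead, dim_feedforward, dropout);
+            # dropout=0.0 keeps the forward deterministic for gradient checks.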
+            constructor_input=FunctionInput(4, 2, 16, 0.0),
+            forward_input=FunctionInput(
+                make_input(shape=(2, 3, 4))
+            ),
+            desc='relu_activation'
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
+            forward_input=FunctionInput(
+                make_input(shape=(2, 3, 4))
+            ),
+            desc='gelu_activation'
+        ),
+    ]
+
+
+def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
+    make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
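+    # Embedding consumes integer indices, which are never differentiable, so
+    # the index tensors use dtype=torch.long and requires_grad=False.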
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
+            forward_input=FunctionInput(make_empty(2, 3).random_(4))
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
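+            # expand() produces a non-contiguous view, exercising the
+            # discontiguous-input path.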
+            forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
+            desc='discontiguous'
+        ),
+    ]
+
+
# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
    ModuleInfo(torch.nn.AvgPool1d,
               module_inputs_func=module_inputs_torch_nn_AvgPool1d),
    ModuleInfo(torch.nn.Linear,
               module_inputs_func=module_inputs_torch_nn_Linear),
    ModuleInfo(torch.nn.NLLLoss,
               module_inputs_func=module_inputs_torch_nn_NLLLoss),
+    ModuleInfo(torch.nn.Embedding,
+               module_inputs_func=module_inputs_torch_nn_Embedding),
+    ModuleInfo(torch.nn.Hardswish,
+               module_inputs_func=module_inputs_torch_nn_Hardswish,
+               supports_gradgrad=False),
+    ModuleInfo(torch.nn.TransformerEncoderLayer,
+               module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
+               supports_gradgrad=False),
    ModuleInfo(torch.nn.ReLU,
               module_inputs_func=module_inputs_torch_nn_ReLU),
]