from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
- TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, gradgradcheck)
+ TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
from unittest.mock import patch
self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
- def _test_gradients_helper(self, device, dtype, module_info, check):
- # Check gradients
- module_cls = module_info.module_cls
- module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
- requires_grad=True)
-
- for module_input in module_inputs:
- if module_input.forward_input is None:
- continue
-
- # === Instantiate the module. ===
- args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
- m = module_cls(*args, **kwargs)
- m.to(device).to(dtype)
-
- params = tuple(m.parameters())
-
- # === Perform gradient check on the input_args ===
- input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
-
- other_kwargs = {}
- kwarg_tensors = []
- for name, obj in input_kwargs.items():
- if isinstance(obj, torch.Tensor):
- kwarg_tensors.append((name, obj))
- else:
- other_kwargs[name] = obj
-
- grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
-
- def fn_to_gradcheck(*input_and_params):
- new_input_args = input_and_params[:len(input_args)]
- kwarg_args = input_and_params[-len(kwarg_tensors):]
- new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}
-
- with freeze_rng_state():
- return m(*new_input_args, **new_kwargs, **other_kwargs)
-
- self.assertTrue(check(fn_to_gradcheck, grad_input))
-
-
- @modules(module_db, allowed_dtypes=[torch.double])
- def test_grad(self, device, dtype, module_info):
- self._test_gradients_helper(device, dtype, module_info, gradcheck)
-
- @modules(module_db, allowed_dtypes=[torch.double])
- def test_gradgrad(self, device, dtype, module_info):
- if not module_info.supports_gradgrad:
- self.skipTest("Skipped! Module does not support gradgrad")
- self._test_gradients_helper(device, dtype, module_info, gradgradcheck)
-
-
instantiate_device_type_tests(TestModule, globals())
if __name__ == '__main__':
from copy import deepcopy
from functools import wraps, partial
from itertools import chain
-import torch.nn.functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_device_type import (
class modules(_TestParametrizer):
""" PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
- def __init__(self, module_info_list, allowed_dtypes=None):
+ def __init__(self, module_info_list):
self.module_info_list = module_info_list
- self.allowed_dtypes = set(allowed_dtypes) if allowed_dtypes is not None else None
def _parametrize_test(self, test, generic_cls, device_cls):
for module_info in self.module_info_list:
# TODO: Factor some of this out since it's similar to OpInfo.
- dtypes = set(module_info.dtypes)
- if self.allowed_dtypes is not None:
- dtypes = dtypes.intersection(self.allowed_dtypes)
-
- for dtype in dtypes:
+ for dtype in floating_types():
# Construct the test name.
test_name = '{}_{}_{}{}'.format(test.__name__,
module_info.name.replace('.', '_'),
module_inputs_func, # Function to generate module inputs
skips=(), # Indicates which tests to skip
decorators=None, # Additional decorators to apply to generated tests
- dtypes=floating_types(), # dtypes this function is expected to work with
- supports_gradgrad=True, # whether the op supports second order gradients
):
self.module_cls = module_cls
self.module_inputs_func = module_inputs_func
self.skips = skips
self.decorators = decorators
- self.dtypes = dtypes
- self.supports_gradgrad = supports_gradgrad
def should_skip(self, cls_name, test_name, device_type, dtype):
return any(si.is_active(cls_name, test_name, device_type, dtype) for si in self.skips)
module_inputs = [
ModuleInput(constructor_input=FunctionInput(10, 8),
- forward_input=FunctionInput(input=make_input((4, 10))),
- reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
+ forward_input=FunctionInput(make_input((4, 10))),
+ reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
forward_input=FunctionInput(make_input((4, 10))),
desc='no_bias',
def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
- make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[Tuple[str, dict]] = [
('', {}),
('ignore_index', {'ignore_index': 2}),
- ('weights', {'weight': make_weight(10).abs()}),
- ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
- ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
+ ('weights', {'weight': make_input(10)}),
+ ('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}),
+ ('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1})
]
-
- # TODO: Uncomment when negative weights is supported.
- # negative_weight = make_weight(10)
- # negative_weight[0] = -1
- # cases.append(('weights_negative', {'weight': negative_weight}))
module_inputs = []
for desc, constructor_kwargs in cases:
desc='scalar')] + generate_regression_criterion_inputs(make_input)
-def module_inputs_torch_nn_Hardswish(module_info, device, dtype, requires_grad, **kwargs):
- make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
- return [
- ModuleInput(
- constructor_input=FunctionInput(),
- forward_input=FunctionInput(make_input(shape=4)),
- reference_fn=no_batch_dim_reference_fn,
- desc='no_batch_dim',
- )
- ]
-
-
-def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, requires_grad, **kwargs):
- make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-
- return [
- ModuleInput(
- constructor_input=FunctionInput(4, 2, 16, 0.0),
- forward_input=FunctionInput(
- make_input(shape=(2, 3, 4))
- ),
- desc='relu_activation'
- ),
- ModuleInput(
- constructor_input=FunctionInput(4, 2, 8, 0.0, F.gelu),
- forward_input=FunctionInput(
- make_input(shape=(2, 3, 4))
- ),
- desc='gelu_activation'
- ),
- ]
-
-
-def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, **kwargs):
- make_empty = partial(torch.empty, device=device, dtype=torch.long, requires_grad=False)
- return [
- ModuleInput(
- constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
- forward_input=FunctionInput(make_empty(2, 3).random_(4))
- ),
- ModuleInput(
- constructor_input=FunctionInput(num_embeddings=4, embedding_dim=3),
- forward_input=FunctionInput(make_empty(1, 512).random_(4).expand(7, 512)),
- desc='discontiguous'
- ),
- ]
-
-
# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
ModuleInfo(torch.nn.AvgPool1d,
module_inputs_func=module_inputs_torch_nn_Linear),
ModuleInfo(torch.nn.NLLLoss,
module_inputs_func=module_inputs_torch_nn_NLLLoss),
- ModuleInfo(torch.nn.Hardswish,
- module_inputs_func=module_inputs_torch_nn_Hardswish,
- supports_gradgrad=False),
- ModuleInfo(torch.nn.TransformerEncoderLayer,
- module_inputs_func=module_inputs_torch_nn_TransformerEncoderLayer,
- supports_gradgrad=False),
- ModuleInfo(torch.nn.Embedding,
- module_inputs_func=module_inputs_torch_nn_Embedding),
ModuleInfo(torch.nn.ReLU,
module_inputs_func=module_inputs_torch_nn_ReLU),
]