From ba126df61448ca3442ec77374bc32f43fcdd9773 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 24 Aug 2021 19:03:07 -0700 Subject: [PATCH] TST Adds more modules into common module tests (#62999) Summary: This PR moves some modules into `common_modules` to see what it looks like. While migrating some no batch modules into `common_modules`, I noticed that `desc` is not used for the name. This means we can not use `-k` to filter tests. This PR moves the sample generation into `_parametrize_test`, and passes in the already generated `module_input` into users of `modules(modules_db)`. I can see this is a little different from opsinfo and would be happy to revert to the original implementation of `modules`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/62999 Reviewed By: heitorschueroff Differential Revision: D30522737 Pulled By: jbschlosser fbshipit-source-id: 7ed1aeb3753fc97a4ad6f1a3c789727c78e1bc73 --- torch/testing/_internal/common_modules.py | 100 +++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 2 deletions(-) diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 088e66f..99525a7 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -5,8 +5,8 @@ from itertools import chain from torch.testing import floating_types from torch.testing._internal.common_device_type import ( _TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf) -from torch.testing._internal.common_nn import nllloss_reference -from torch.testing._internal.common_utils import make_tensor +from torch.testing._internal.common_nn import nllloss_reference, get_reduction +from torch.testing._internal.common_utils import make_tensor, freeze_rng_state from types import ModuleType from typing import List, Tuple, Type, Set, Dict @@ -46,6 +46,7 @@ for namespace in MODULE_NAMESPACES: class modules(_TestParametrizer): """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """ + def __init__(self, module_info_list): self.module_info_list = module_info_list @@ -199,8 +200,103 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, ** return module_inputs +def no_batch_dim_reference_fn(m, p, *args, **kwargs): + """Reference function for modules supporting no batch dimensions. + + The module is passed the input and target in batched form with a single item. + The output is squeezed to compare with the no-batch input. + """ + single_batch_input_args = [input.unsqueeze(0) for input in args] + with freeze_rng_state(): + return m(*single_batch_input_args).squeeze(0) + + +def no_batch_dim_reference_criterion_fn(m, *args, **kwargs): + """Reference function for criterion supporting no batch dimensions.""" + output = no_batch_dim_reference_fn(m, *args, **kwargs) + reduction = get_reduction(m) + if reduction == 'none': + return output.squeeze(0) + # reduction is 'sum' or 'mean' which results in a 0D tensor + return output + + +def generate_regression_criterion_inputs(make_input): + return [ + ModuleInput( + constructor_input=FunctionInput(reduction=reduction), + forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)), + reference_fn=no_batch_dim_reference_criterion_fn, + desc='no_batch_dim_{}'.format(reduction) + ) for reduction in ['none', 'mean', 'sum']] + + +def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(kernel_size=2), + forward_input=FunctionInput(make_input(size=(3, 6))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(size=(3, 2, 5))), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(size=())), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(size=(3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(size=(3, 2, 5))), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(size=())), + reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)), + desc='scalar'), + ModuleInput(constructor_input=FunctionInput(alpha=2.), + forward_input=FunctionInput(make_input(size=(3,))), + desc='no_batch_dim', + reference_fn=no_batch_dim_reference_fn)] + + +def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(size=(2, 3, 4)), + make_input(size=(2, 3, 4))), + reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum() + for a, b in zip(i, t))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(size=()), make_input(size=())), + reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(), + desc='scalar')] + generate_regression_criterion_inputs(make_input) + + # Database of ModuleInfo entries in alphabetical order. module_db: List[ModuleInfo] = [ + ModuleInfo(torch.nn.AvgPool1d, + module_inputs_func=module_inputs_torch_nn_AvgPool1d), + ModuleInfo(torch.nn.ELU, + module_inputs_func=module_inputs_torch_nn_ELU), + ModuleInfo(torch.nn.L1Loss, + module_inputs_func=module_inputs_torch_nn_L1Loss), ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear), ModuleInfo(torch.nn.NLLLoss, -- 2.7.4