TST Adds more modules into common module tests (#62999)
authorThomas J. Fan <thomasjpfan@gmail.com>
Wed, 25 Aug 2021 02:03:07 +0000 (19:03 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 02:16:32 +0000 (19:16 -0700)
Summary:
This PR moves some modules into `common_modules` to see what it looks like.

While migrating some no batch modules into `common_modules`, I noticed that `desc` is not used for the name. This means we can not use `-k` to filter tests. This PR moves the sample generation into `_parametrize_test`, and passes in the already generated `module_input` into users of `modules(modules_db)`.

I can see this is a little different from opsinfo and would be happy to revert to the original implementation of `modules`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62999

Reviewed By: heitorschueroff

Differential Revision: D30522737

Pulled By: jbschlosser

fbshipit-source-id: 7ed1aeb3753fc97a4ad6f1a3c789727c78e1bc73

torch/testing/_internal/common_modules.py

index 088e66f..99525a7 100644 (file)
@@ -5,8 +5,8 @@ from itertools import chain
 from torch.testing import floating_types
 from torch.testing._internal.common_device_type import (
     _TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf)
-from torch.testing._internal.common_nn import nllloss_reference
-from torch.testing._internal.common_utils import make_tensor
+from torch.testing._internal.common_nn import nllloss_reference, get_reduction
+from torch.testing._internal.common_utils import make_tensor, freeze_rng_state
 from types import ModuleType
 from typing import List, Tuple, Type, Set, Dict
 
@@ -46,6 +46,7 @@ for namespace in MODULE_NAMESPACES:
 
 class modules(_TestParametrizer):
     """ PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
+
     def __init__(self, module_info_list):
         self.module_info_list = module_info_list
 
@@ -199,8 +200,103 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **
     return module_inputs
 
 
+def no_batch_dim_reference_fn(m, p, *args, **kwargs):
+    """Reference function for modules supporting no batch dimensions.
+
+    The module is passed the input and target in batched form with a single item.
+    The output is squeezed to compare with the no-batch input.
+    """
+    single_batch_input_args = [input.unsqueeze(0) for input in args]
+    with freeze_rng_state():
+        return m(*single_batch_input_args).squeeze(0)
+
+
+def no_batch_dim_reference_criterion_fn(m, *args, **kwargs):
+    """Reference function for criterion supporting no batch dimensions."""
+    output = no_batch_dim_reference_fn(m, *args, **kwargs)
+    reduction = get_reduction(m)
+    if reduction == 'none':
+        return output.squeeze(0)
+    # reduction is 'sum' or 'mean' which results in a 0D tensor
+    return output
+
+
+def generate_regression_criterion_inputs(make_input):
+    return [
+        ModuleInput(
+            constructor_input=FunctionInput(reduction=reduction),
+            forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)),
+            reference_fn=no_batch_dim_reference_criterion_fn,
+            desc='no_batch_dim_{}'.format(reduction)
+        ) for reduction in ['none', 'mean', 'sum']]
+
+
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(kernel_size=2),
+                    forward_input=FunctionInput(make_input(size=(3, 6))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(size=())),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(size=(3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(size=())),
+                    reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)),
+                    desc='scalar'),
+        ModuleInput(constructor_input=FunctionInput(alpha=2.),
+                    forward_input=FunctionInput(make_input(size=(3,))),
+                    desc='no_batch_dim',
+                    reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    return [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(size=(2, 3, 4)),
+                                                make_input(size=(2, 3, 4))),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
+                                                                         for a, b in zip(i, t))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(size=()), make_input(size=())),
+                    reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
+                    desc='scalar')] + generate_regression_criterion_inputs(make_input)
+
+
 # Database of ModuleInfo entries in alphabetical order.
 module_db: List[ModuleInfo] = [
+    ModuleInfo(torch.nn.AvgPool1d,
+               module_inputs_func=module_inputs_torch_nn_AvgPool1d),
+    ModuleInfo(torch.nn.ELU,
+               module_inputs_func=module_inputs_torch_nn_ELU),
+    ModuleInfo(torch.nn.L1Loss,
+               module_inputs_func=module_inputs_torch_nn_L1Loss),
     ModuleInfo(torch.nn.Linear,
                module_inputs_func=module_inputs_torch_nn_Linear),
     ModuleInfo(torch.nn.NLLLoss,