from torch.testing import floating_types
from torch.testing._internal.common_device_type import (
_TestParametrizer, _dtype_test_suffix, _update_param_kwargs, skipIf)
-from torch.testing._internal.common_nn import nllloss_reference
-from torch.testing._internal.common_utils import make_tensor
+from torch.testing._internal.common_nn import nllloss_reference, get_reduction
+from torch.testing._internal.common_utils import make_tensor, freeze_rng_state
from types import ModuleType
from typing import List, Tuple, Type, Set, Dict
class modules(_TestParametrizer):
""" PROTOTYPE: Decorator for specifying a list of modules over which to run a test. """
+
def __init__(self, module_info_list):
self.module_info_list = module_info_list
return module_inputs
+def no_batch_dim_reference_fn(m, p, *args, **kwargs):
+ """Reference function for modules supporting no batch dimensions.
+
+ The module is passed the input and target in batched form with a single item.
+ The output is squeezed to compare with the no-batch input.
+ """
+ single_batch_input_args = [input.unsqueeze(0) for input in args]
+ with freeze_rng_state():
+ return m(*single_batch_input_args).squeeze(0)
+
+
+def no_batch_dim_reference_criterion_fn(m, *args, **kwargs):
+ """Reference function for criterion supporting no batch dimensions."""
+ output = no_batch_dim_reference_fn(m, *args, **kwargs)
+ reduction = get_reduction(m)
+ if reduction == 'none':
+ return output.squeeze(0)
+ # reduction is 'sum' or 'mean' which results in a 0D tensor
+ return output
+
+
+def generate_regression_criterion_inputs(make_input):
+ return [
+ ModuleInput(
+ constructor_input=FunctionInput(reduction=reduction),
+ forward_input=FunctionInput(make_input(size=(4, )), make_input(size=4,)),
+ reference_fn=no_batch_dim_reference_criterion_fn,
+ desc='no_batch_dim_{}'.format(reduction)
+ ) for reduction in ['none', 'mean', 'sum']]
+
+
+def module_inputs_torch_nn_AvgPool1d(module_info, device, dtype, requires_grad, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ return [
+ ModuleInput(constructor_input=FunctionInput(kernel_size=2),
+ forward_input=FunctionInput(make_input(size=(3, 6))),
+ desc='no_batch_dim',
+ reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_ELU(module_info, device, dtype, requires_grad, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ return [
+ ModuleInput(constructor_input=FunctionInput(alpha=2.),
+ forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+ reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1))),
+ ModuleInput(constructor_input=FunctionInput(alpha=2.),
+ forward_input=FunctionInput(make_input(size=())),
+ desc='scalar'),
+ ModuleInput(constructor_input=FunctionInput(),
+ forward_input=FunctionInput(make_input(size=(3,))),
+ desc='no_batch_dim',
+ reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ return [
+ ModuleInput(constructor_input=FunctionInput(alpha=2.),
+ forward_input=FunctionInput(make_input(size=(3, 2, 5))),
+ reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2. * ((.5 * i).exp() - 1))),
+ ModuleInput(constructor_input=FunctionInput(alpha=2.),
+ forward_input=FunctionInput(make_input(size=())),
+ reference_fn=lambda m, p, i: torch.where(i >= 0, i, 2 * (i.exp() - 1)),
+ desc='scalar'),
+ ModuleInput(constructor_input=FunctionInput(alpha=2.),
+ forward_input=FunctionInput(make_input(size=(3,))),
+ desc='no_batch_dim',
+ reference_fn=no_batch_dim_reference_fn)]
+
+
+def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ return [
+ ModuleInput(constructor_input=FunctionInput(),
+ forward_input=FunctionInput(make_input(size=(2, 3, 4)),
+ make_input(size=(2, 3, 4))),
+ reference_fn=lambda m, p, i, t: 1. / i.numel() * sum((a - b).abs().sum()
+ for a, b in zip(i, t))),
+ ModuleInput(constructor_input=FunctionInput(),
+ forward_input=FunctionInput(make_input(size=()), make_input(size=())),
+ reference_fn=lambda m, p, i, t: 1. / i.numel() * (i - t).abs().sum(),
+ desc='scalar')] + generate_regression_criterion_inputs(make_input)
+
+
# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
+ ModuleInfo(torch.nn.AvgPool1d,
+ module_inputs_func=module_inputs_torch_nn_AvgPool1d),
+ ModuleInfo(torch.nn.ELU,
+ module_inputs_func=module_inputs_torch_nn_ELU),
+ ModuleInfo(torch.nn.L1Loss,
+ module_inputs_func=module_inputs_torch_nn_L1Loss),
ModuleInfo(torch.nn.Linear,
module_inputs_func=module_inputs_torch_nn_Linear),
ModuleInfo(torch.nn.NLLLoss,