),
dict(
module_name='Threshold',
- constructor_args=(2, 1),
+ constructor_args=(2., 1.),
input_size=(2, 3, 4, 5),
check_inplace=True,
desc='threshold_value'
),
dict(
module_name='Threshold',
- constructor_args=(2, 10),
+ constructor_args=(2., 10.),
input_size=(2, 3, 4, 5),
desc='large_value'
),
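# The float constructor args above line up with the float-typed threshold/value
# of the weak-scripted nn.Threshold, so the same dicts can drive the scripted
# module tests as well as the eager ones. A minimal sketch of what the
# generated test exercises (same shapes as above):
#
#   m = torch.nn.Threshold(2., 1.)
#   y = m(torch.randn(2, 3, 4, 5))   # elements <= 2. are replaced by 1.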
%25 : int[] = prim::ListConstruct(%24, %24), scope: AlexNet/Sequential[features]/Conv2d[0]
%26 : bool = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0]
%input.1 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[0]
- %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %24, %24), scope: AlexNet/Sequential[features]/ReLU[1]
- %29 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
- %30 : int[] = prim::ListConstruct(%29, %29), scope: AlexNet/Sequential[features]/MaxPool2d[2]
- %31 : Double(1, 64, 27, 27), %32 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
- %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%31, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
- %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %24, %24), scope: AlexNet/Sequential[features]/ReLU[4]
- %35 : Double(1, 192, 13, 13), %36 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
- %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%35, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
- %38 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %24, %24), scope: AlexNet/Sequential[features]/ReLU[7]
- %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%38, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
- %40 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %24, %24), scope: AlexNet/Sequential[features]/ReLU[9]
- %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%40, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
- %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %24, %24), scope: AlexNet/Sequential[features]/ReLU[11]
- %43 : Double(1, 256, 6, 6), %44 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
- %45 : int = aten::size(%43, %24), scope: AlexNet
- %46 : Long() = prim::NumToTensor(%45), scope: AlexNet
- %47 : int = prim::TensorToNum(%46), scope: AlexNet
- %48 : int = prim::Constant[value=9216](), scope: AlexNet
- %49 : int[] = prim::ListConstruct(%47, %48), scope: AlexNet
- %input.9 : Double(1, 9216) = aten::view(%43, %49), scope: AlexNet
- %51 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
- %input.10 : Double(1, 9216) = aten::dropout(%input.9, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
- %53 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
- %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %53, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
- %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[2]
- %input.13 : Double(1, 4096) = aten::dropout(%input.12, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
- %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
- %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
- %input : Double(1, 4096) = aten::threshold_(%input.14, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[5]
- %60 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
- %61 : Double(1, 1000) = aten::addmm(%16, %input, %60, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
- return (%61);
+ %28 : float = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/ReLU[1]
+ %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %28, %28), scope: AlexNet/Sequential[features]/ReLU[1]
+ %30 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+ %31 : int[] = prim::ListConstruct(%30, %30), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+ %32 : Double(1, 64, 27, 27), %33 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+ %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%32, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
+ %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %28, %28), scope: AlexNet/Sequential[features]/ReLU[4]
+ %36 : Double(1, 192, 13, 13), %37 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
+ %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%36, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
+ %39 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %28, %28), scope: AlexNet/Sequential[features]/ReLU[7]
+ %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%39, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
+ %41 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %28, %28), scope: AlexNet/Sequential[features]/ReLU[9]
+ %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%41, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
+ %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %28, %28), scope: AlexNet/Sequential[features]/ReLU[11]
+ %44 : Double(1, 256, 6, 6), %45 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
+ %46 : int = aten::size(%44, %24), scope: AlexNet
+ %47 : Long() = prim::NumToTensor(%46), scope: AlexNet
+ %48 : int = prim::TensorToNum(%47), scope: AlexNet
+ %49 : int = prim::Constant[value=9216](), scope: AlexNet
+ %50 : int[] = prim::ListConstruct(%48, %49), scope: AlexNet
+ %input.9 : Double(1, 9216) = aten::view(%44, %50), scope: AlexNet
+ %52 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
+ %input.10 : Double(1, 9216) = aten::dropout(%input.9, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
+ %54 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
+ %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %54, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
+ %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[2]
+ %input.13 : Double(1, 4096) = aten::dropout(%input.12, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
+ %58 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
+ %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %58, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
+ %input : Double(1, 4096) = aten::threshold_(%input.14, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[5]
+ %61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
+ %62 : Double(1, 1000) = aten::addmm(%16, %input, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
+ return (%62);
}
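# In the updated trace above, the in-place ReLUs consume a dedicated float
# constant (%28 = 0) for aten::threshold_ instead of reusing the int constant
# %24, matching the float threshold/value of the scripted module. Roughly, each
# ReLU node corresponds to the functional call sketched here:
#
#   x = torch.randn(1, 64, 55, 55, dtype=torch.double)
#   torch.nn.functional.threshold_(x, 0., 0.)   # in-place; what ReLU lowers to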
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
from torch._six import inf, PY2
-from common_utils import (TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN,
- skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE)
+from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
+ skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
+ freeze_rng_state
+from test_nn import module_tests, new_module_tests
from textwrap import dedent
import os
import io
from test_module.future_div import div_int_future, div_float_future
from test_module.no_future_div import div_int_nofuture, div_float_nofuture
+
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
def runAndSaveRNG(self, func, inputs, kwargs=None):
kwargs = kwargs if kwargs else {}
- initial_rng_state = torch.get_rng_state()
- results = func(*inputs, **kwargs)
- torch.set_rng_state(initial_rng_state)
+ with freeze_rng_state():
+ results = func(*inputs, **kwargs)
return results
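# freeze_rng_state (imported from common_utils above) packages the manual
# save/restore of the RNG state that the old body did by hand. Roughly, the
# helper behaves like the sketch below (the real implementation in common_utils
# may also snapshot the CUDA RNG state):
#
#   @contextlib.contextmanager
#   def freeze_rng_state():
#       state = torch.get_rng_state()
#       try:
#           yield
#       finally:
#           torch.set_rng_state(state)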
'test_nn_max_unpool1d',
}
+EXCLUDE_SCRIPT_MODULES = {
+ 'test_nn_LPPool2d_norm',
+ 'test_nn_LPPool1d_norm',
+}
+
DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
'test_nn_avg_pool2d',
'test_nn_log_softmax',
post_add_test(test_name, skipTestIf, do_test)
-def add_nn_module_test(module_name, constructor_args, call_args,
- use_as_constant=False, skipTestIf=()):
+def add_nn_module_test(*args, **kwargs):
+ if 'module_name' in kwargs:
+ name = kwargs['module_name']
+ elif 'fullname' in kwargs:
+ name = kwargs['fullname']
+ elif 'constructor' in kwargs:
+ name = kwargs['constructor'].__name__
+
+ class_name = name.split("_")[0]
+
+ module = getattr(torch.nn, class_name, None)
+ if module is None or torch._jit_internal._weak_types.get(module) is None:
+ return
+
+ if 'desc' in kwargs and 'eval' in kwargs['desc']:
+ # eval() is not supported, so skip these tests
+ return
+
+ test_name = name
+ if 'desc' in kwargs:
+ test_name = "{}_{}".format(test_name, kwargs['desc'])
+ test_name = 'test_nn_{}'.format(test_name)
+
def do_test(self):
- nn_module = getattr(torch.nn, module_name)
+ if test_name in EXCLUDE_SCRIPT_MODULES:
+ return
+ if 'constructor' in kwargs:
+ nn_module = kwargs['constructor']
+ else:
+ nn_module = getattr(torch.nn, name)
+ constructor_args = kwargs.get('constructor_args', ())
# Construct a script module that passes arguments through
# to self.submodule
script = script_method_template.format(method_args, call)
submodule_constants = []
- if use_as_constant:
+ if kwargs.get('is_constant'):
submodule_constants = ['submodule']
# Create module to use the script method
return module(*args)
# Check against Python module as reference
- args_variable, kwargs_variable = create_input(call_args)
+ if 'input_fn' in kwargs:
+ input_size = tuple(kwargs['input_fn']().size())
+ else:
+ input_size = kwargs['input_size']
+ args_variable, kwargs_variable = create_input((input_size,))
f_args_variable = deepcopy(unpack_variables(args_variable))
check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
- test_name = 'test_nn_{}'.format(module_name)
- post_add_test(test_name, skipTestIf, do_test)
+ post_add_test(test_name, (), do_test)
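# As an example of the kwargs-driven registration (which replaces the old
# positional (module_name, constructor_args, call_args) interface), the
# Threshold entry shown at the top of this diff ends up here as:
#
#   add_nn_module_test(module_name='Threshold', constructor_args=(2., 1.),
#                      input_size=(2, 3, 4, 5), check_inplace=True,
#                      desc='threshold_value')
#   # -> registers do_test as 'test_nn_Threshold_threshold_value'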
def post_add_test(test_name, skipTestIf, do_test):
for test in nn_functional_tests:
add_nn_functional_test(*test)
-for test in nn_module_tests:
- add_nn_module_test(*test)
+for test in module_tests + new_module_tests:
+ add_nn_module_test(**test)
if __name__ == '__main__':
run_tests()
),
dict(
module_name='LPPool2d',
- constructor_args=(2, (2, 2), 2),
+ constructor_args=(2, 2, 2),
input_size=(1, 3, 7, 7),
),
dict(
),
dict(
module_name='Threshold',
- constructor_args=(2, 1),
+ constructor_args=(2., 1.),
input_size=(),
check_inplace=True,
desc='threshold_value_scalar'
@torch._jit_internal.weak_script
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
- # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
+ # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
r"""Applies a 2D power-average pooling over an input signal composed of
several input planes. If the sum of all inputs to the power of `p` is
zero, the gradient is set to zero as well.
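# Power-average pooling computes (sum_i x_i ** norm_type) ** (1 / norm_type)
# over each window; norm_type=1 gives a plain sum pool, and as norm_type grows
# it approaches max pooling. A tiny sketch with an all-ones input:
#
#   x = torch.ones(1, 1, 2, 2)
#   torch.nn.functional.lp_pool2d(x, norm_type=2., kernel_size=2)   # (1+1+1+1) ** 0.5 -> 2.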
@torch._jit_internal.weak_module
class Threshold(Module):
- __constants__ = ['threshold', 'value', 'inplace']
-
r"""Thresholds each element of the input Tensor
Threshold is defined as:
>>> input = torch.randn(2)
>>> output = m(input)
"""
+ __constants__ = ['threshold', 'value', 'inplace']
def __init__(self, threshold, value, inplace=False):
super(Threshold, self).__init__()
)
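# The pattern repeated through the rest of this file: @weak_module marks a
# module class for lazy compilation when it is used from a ScriptModule,
# @weak_script_method marks its forward for compilation, and __constants__
# lists attributes the compiler may fold as constants. Roughly (Foo is a
# hypothetical module, not part of this diff):
#
#   @torch._jit_internal.weak_module
#   class Foo(Module):
#       __constants__ = ['alpha']
#
#       def __init__(self, alpha=1.):
#           super(Foo, self).__init__()
#           self.alpha = alpha
#
#       @torch._jit_internal.weak_script_method
#       def forward(self, input):
#           return input * self.alpha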
+@torch._jit_internal.weak_module
class ReLU(Threshold):
r"""Applies the rectified linear unit function element-wise
:math:`\text{ReLU}(x)= \max(0, x)`
"""
def __init__(self, inplace=False):
- super(ReLU, self).__init__(0, 0, inplace)
+ super(ReLU, self).__init__(0., 0., inplace)
def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
+@torch._jit_internal.weak_module
class Hardtanh(Module):
r"""Applies the HardTanh function element-wise
>>> input = torch.randn(2)
>>> output = m(input)
"""
+ __constants__ = ['min_val', 'max_val', 'inplace']
- def __init__(self, min_val=-1, max_val=1, inplace=False, min_value=None, max_value=None):
+ def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
super(Hardtanh, self).__init__()
if min_value is not None:
warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
self.inplace = inplace
assert self.max_val > self.min_val
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
)
+@torch._jit_internal.weak_module
class ReLU6(Hardtanh):
r"""Applies the element-wise function:
"""
def __init__(self, inplace=False):
- super(ReLU6, self).__init__(0, 6, inplace)
+ super(ReLU6, self).__init__(0., 6., inplace)
def extra_repr(self):
inplace_str = 'inplace' if self.inplace else ''
return torch.tanh(input)
+@torch._jit_internal.weak_module
class ELU(Module):
r"""Applies the element-wise function:
>>> input = torch.randn(2)
>>> output = m(input)
"""
+ __constants__ = ['alpha', 'inplace']
def __init__(self, alpha=1., inplace=False):
super(ELU, self).__init__()
self.alpha = alpha
self.inplace = inplace
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.elu(input, self.alpha, self.inplace)
return 'alpha={}{}'.format(self.alpha, inplace_str)
+@torch._jit_internal.weak_module
class CELU(Module):
r"""Applies the element-wise function:
.. _`Continuously Differentiable Exponential Linear Units`:
https://arxiv.org/abs/1704.07483
"""
+ __constants__ = ['alpha', 'inplace']
def __init__(self, alpha=1., inplace=False):
super(CELU, self).__init__()
self.alpha = alpha
self.inplace = inplace
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.celu(input, self.alpha, self.inplace)
return 'alpha={}{}'.format(self.alpha, inplace_str)
+@torch._jit_internal.weak_module
class SELU(Module):
r"""Applied element-wise, as:
.. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
"""
+ __constants__ = ['inplace']
def __init__(self, inplace=False):
super(SELU, self).__init__()
self.inplace = inplace
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.selu(input, self.inplace)
return inplace_str
+@torch._jit_internal.weak_module
class GLU(Module):
r"""Applies the gated linear unit function
:math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
>>> input = torch.randn(4, 2)
>>> output = m(input)
"""
+ __constants__ = ['dim']
def __init__(self, dim=-1):
super(GLU, self).__init__()
self.dim = dim
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.glu(input, self.dim)
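# GLU splits the input in half along `dim` and gates the first half with the
# sigmoid of the second, so the gated dimension shrinks by a factor of two.
# A small sketch in the spirit of the docstring example:
#
#   m = torch.nn.GLU(dim=1)
#   y = m(torch.randn(4, 2))   # shape (4, 1): x[:, :1] * sigmoid(x[:, 1:])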
return '{}'.format(self.lambd)
+@torch._jit_internal.weak_module
class LeakyReLU(Module):
r"""Applies the element-wise function:
>>> input = torch.randn(2)
>>> output = m(input)
"""
+ __constants__ = ['inplace', 'negative_slope']
def __init__(self, negative_slope=1e-2, inplace=False):
super(LeakyReLU, self).__init__()
self.negative_slope = negative_slope
self.inplace = inplace
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.leaky_relu(input, self.negative_slope, self.inplace)
return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
+@torch._jit_internal.weak_module
class LogSigmoid(Module):
r"""Applies the element-wise function:
>>> output = m(input)
"""
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.logsigmoid(input)
+@torch._jit_internal.weak_module
class Softplus(Module):
r"""Applies the element-wise function:
>>> input = torch.randn(2)
>>> output = m(input)
"""
+ __constants__ = ['beta', 'threshold']
def __init__(self, beta=1, threshold=20):
super(Softplus, self).__init__()
self.beta = beta
self.threshold = threshold
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.softplus(input, self.beta, self.threshold)
return F.tanhshrink(input)
+@torch._jit_internal.weak_module
class Softmin(Module):
r"""Applies the Softmin function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
+ __constants__ = ['dim']
def __init__(self, dim=None):
super(Softmin, self).__init__()
self.dim = dim
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.softmin(input, self.dim, _stacklevel=5)
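# Softmin is Softmax applied to the negated input, so smaller entries receive
# the larger probabilities; a quick sanity check:
#
#   x = torch.randn(2, 3)
#   torch.allclose(torch.nn.Softmin(dim=1)(x), torch.nn.Softmax(dim=1)(-x))   # True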
+@torch._jit_internal.weak_module
class Softmax(Module):
r"""Applies the Softmax function to an n-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
+ __constants__ = ['dim']
def __init__(self, dim=None):
super(Softmax, self).__init__()
if not hasattr(self, 'dim'):
self.dim = None
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.softmax(input, self.dim, _stacklevel=5)
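# The rescaling means every slice along `dim` becomes a probability vector:
#
#   m = torch.nn.Softmax(dim=1)
#   m(torch.randn(2, 3)).sum(dim=1)   # -> tensor([1., 1.]) up to rounding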
+@torch._jit_internal.weak_module
class Softmax2d(Module):
r"""Applies SoftMax over features to each spatial location.
>>> output = m(input)
"""
+ @torch._jit_internal.weak_script_method
def forward(self, input):
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
return F.softmax(input, 1, _stacklevel=5)
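# For a (N, C, H, W) input this is a softmax over the channel dimension at each
# spatial position, hence the fixed dim=1 above:
#
#   m = torch.nn.Softmax2d()
#   m(torch.randn(2, 3, 4, 4)).sum(dim=1)   # all ones, one per (n, h, w)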
+@torch._jit_internal.weak_module
class LogSoftmax(Module):
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
input Tensor. The LogSoftmax formulation can be simplified as:
>>> input = torch.randn(2, 3)
>>> output = m(input)
"""
+ __constants__ = ['dim']
def __init__(self, dim=None):
super(LogSoftmax, self).__init__()
if not hasattr(self, 'dim'):
self.dim = None
+ @torch._jit_internal.weak_script_method
def forward(self, input):
return F.log_softmax(input, self.dim, _stacklevel=5)
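# LogSoftmax fuses the log and the softmax into a single, numerically stabler
# step; up to floating-point error it matches composing the two:
#
#   x = torch.randn(2, 3)
#   torch.allclose(torch.nn.LogSoftmax(dim=1)(x),
#                  torch.log(torch.nn.Softmax(dim=1)(x)))   # True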