Use nn module tests in test_jit (#14238)
authorDavid Riazati <davidriazati@fb.com>
Wed, 28 Nov 2018 05:17:51 +0000 (21:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 05:19:51 +0000 (21:19 -0800)
Summary:
This PR adds weak modules for all activation modules and uses `test_nn` module tests to test weak modules that have been annotated with `weak_module` and therefore are in `torch._jit_internal._weak_types`

Also depends on #14379
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14238

Differential Revision: D13192230

Pulled By: driazati

fbshipit-source-id: 36488960b6c91448b38c0fa65422539a93af8c5e

test/common_nn.py
test/expect/TestJit.test_alexnet.expect
test/test_jit.py
test/test_nn.py
torch/nn/functional.py
torch/nn/modules/activation.py

index c56ee4d..27f1571 100644 (file)
@@ -51,14 +51,14 @@ module_tests = [
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2, 1),
+        constructor_args=(2., 1.),
         input_size=(2, 3, 4, 5),
         check_inplace=True,
         desc='threshold_value'
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2, 10),
+        constructor_args=(2., 10.),
         input_size=(2, 3, 4, 5),
         desc='large_value'
     ),
index bd9e335..0652078 100644 (file)
@@ -26,36 +26,37 @@ graph(%0 : Double(1, 3, 224, 224)
   %25 : int[] = prim::ListConstruct(%24, %24), scope: AlexNet/Sequential[features]/Conv2d[0]
   %26 : bool = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0]
   %input.1 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[0]
-  %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %24, %24), scope: AlexNet/Sequential[features]/ReLU[1]
-  %29 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %30 : int[] = prim::ListConstruct(%29, %29), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %31 : Double(1, 64, 27, 27), %32 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%31, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
-  %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %24, %24), scope: AlexNet/Sequential[features]/ReLU[4]
-  %35 : Double(1, 192, 13, 13), %36 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
-  %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%35, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
-  %38 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %24, %24), scope: AlexNet/Sequential[features]/ReLU[7]
-  %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%38, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
-  %40 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %24, %24), scope: AlexNet/Sequential[features]/ReLU[9]
-  %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%40, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
-  %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %24, %24), scope: AlexNet/Sequential[features]/ReLU[11]
-  %43 : Double(1, 256, 6, 6), %44 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
-  %45 : int = aten::size(%43, %24), scope: AlexNet
-  %46 : Long() = prim::NumToTensor(%45), scope: AlexNet
-  %47 : int = prim::TensorToNum(%46), scope: AlexNet
-  %48 : int = prim::Constant[value=9216](), scope: AlexNet
-  %49 : int[] = prim::ListConstruct(%47, %48), scope: AlexNet
-  %input.9 : Double(1, 9216) = aten::view(%43, %49), scope: AlexNet
-  %51 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %input.10 : Double(1, 9216) = aten::dropout(%input.9, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %53 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %53, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[2]
-  %input.13 : Double(1, 4096) = aten::dropout(%input.12, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
-  %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %input : Double(1, 4096) = aten::threshold_(%input.14, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[5]
-  %60 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %61 : Double(1, 1000) = aten::addmm(%16, %input, %60, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
-  return (%61);
+  %28 : float = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/ReLU[1]
+  %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %28, %28), scope: AlexNet/Sequential[features]/ReLU[1]
+  %30 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %31 : int[] = prim::ListConstruct(%30, %30), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %32 : Double(1, 64, 27, 27), %33 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%32, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
+  %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %28, %28), scope: AlexNet/Sequential[features]/ReLU[4]
+  %36 : Double(1, 192, 13, 13), %37 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
+  %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%36, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
+  %39 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %28, %28), scope: AlexNet/Sequential[features]/ReLU[7]
+  %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%39, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
+  %41 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %28, %28), scope: AlexNet/Sequential[features]/ReLU[9]
+  %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%41, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
+  %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %28, %28), scope: AlexNet/Sequential[features]/ReLU[11]
+  %44 : Double(1, 256, 6, 6), %45 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
+  %46 : int = aten::size(%44, %24), scope: AlexNet
+  %47 : Long() = prim::NumToTensor(%46), scope: AlexNet
+  %48 : int = prim::TensorToNum(%47), scope: AlexNet
+  %49 : int = prim::Constant[value=9216](), scope: AlexNet
+  %50 : int[] = prim::ListConstruct(%48, %49), scope: AlexNet
+  %input.9 : Double(1, 9216) = aten::view(%44, %50), scope: AlexNet
+  %52 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %input.10 : Double(1, 9216) = aten::dropout(%input.9, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %54 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %54, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[2]
+  %input.13 : Double(1, 4096) = aten::dropout(%input.12, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
+  %58 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %58, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %input : Double(1, 4096) = aten::threshold_(%input.14, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[5]
+  %61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %62 : Double(1, 1000) = aten::addmm(%16, %input, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
+  return (%62);
 }
index 540a13e..0ea0c49 100644 (file)
@@ -11,8 +11,10 @@ from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2
-from common_utils import (TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN,
-                          skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE)
+from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
+    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
+    freeze_rng_state
+from test_nn import module_tests, new_module_tests
 from textwrap import dedent
 import os
 import io
@@ -41,6 +43,7 @@ from torch.jit import BatchTensor
 from test_module.future_div import div_int_future, div_float_future
 from test_module.no_future_div import div_int_nofuture, div_float_nofuture
 
+
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
@@ -446,9 +449,8 @@ class JitTestCase(TestCase):
 
     def runAndSaveRNG(self, func, inputs, kwargs=None):
         kwargs = kwargs if kwargs else {}
-        initial_rng_state = torch.get_rng_state()
-        results = func(*inputs, **kwargs)
-        torch.set_rng_state(initial_rng_state)
+        with freeze_rng_state():
+            results = func(*inputs, **kwargs)
         return results
 
 
@@ -9396,6 +9398,11 @@ EXCLUDE_SCRIPT = {
     'test_nn_max_unpool1d',
 }
 
+EXCLUDE_SCRIPT_MODULES = {
+    'test_nn_LPPool2d_norm',
+    'test_nn_LPPool1d_norm',
+}
+
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
     'test_nn_log_softmax',
@@ -10199,10 +10206,37 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
     post_add_test(test_name, skipTestIf, do_test)
 
 
-def add_nn_module_test(module_name, constructor_args, call_args,
-                       use_as_constant=False, skipTestIf=()):
+def add_nn_module_test(*args, **kwargs):
+    if 'module_name' in kwargs:
+        name = kwargs['module_name']
+    elif 'fullname' in kwargs:
+        name = kwargs['fullname']
+    elif 'constructor' in kwargs:
+        name = kwargs['constructor'].__name__
+
+    class_name = name.split("_")[0]
+
+    module = getattr(torch.nn, class_name, None)
+    if module is None or torch._jit_internal._weak_types.get(module) is None:
+        return
+
+    if 'desc' in kwargs and 'eval' in kwargs['desc']:
+        # eval() is not supported, so skip these tests
+        return
+
+    test_name = name
+    if 'desc' in kwargs:
+        test_name = "{}_{}".format(test_name, kwargs['desc'])
+    test_name = 'test_nn_{}'.format(test_name)
+
     def do_test(self):
-        nn_module = getattr(torch.nn, module_name)
+        if test_name in EXCLUDE_SCRIPT_MODULES:
+            return
+        if 'constructor' in kwargs:
+            nn_module = kwargs['constructor']
+        else:
+            nn_module = getattr(torch.nn, name)
+        constructor_args = kwargs.get('constructor_args', ())
 
         # Construct a script module that passes arguments through
         # to self.submodule
@@ -10215,7 +10249,7 @@ def add_nn_module_test(module_name, constructor_args, call_args,
             script = script_method_template.format(method_args, call)
 
             submodule_constants = []
-            if use_as_constant:
+            if kwargs.get('is_constant'):
                 submodule_constants = ['submodule']
 
             # Create module to use the script method
@@ -10241,13 +10275,16 @@ def add_nn_module_test(module_name, constructor_args, call_args,
             return module(*args)
 
         # Check against Python module as reference
-        args_variable, kwargs_variable = create_input(call_args)
+        if 'input_fn' in kwargs:
+            input_size = tuple(kwargs['input_fn']().size())
+        else:
+            input_size = kwargs['input_size']
+        args_variable, kwargs_variable = create_input((input_size,))
         f_args_variable = deepcopy(unpack_variables(args_variable))
 
         check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
 
-    test_name = 'test_nn_{}'.format(module_name)
-    post_add_test(test_name, skipTestIf, do_test)
+    post_add_test(test_name, (), do_test)
 
 
 def post_add_test(test_name, skipTestIf, do_test):
@@ -10411,8 +10448,8 @@ for test in autograd_method_tests:
 for test in nn_functional_tests:
     add_nn_functional_test(*test)
 
-for test in nn_module_tests:
-    add_nn_module_test(*test)
+for test in module_tests + new_module_tests:
+    add_nn_module_test(**test)
 
 if __name__ == '__main__':
     run_tests()
index de61e68..f6aa527 100644 (file)
@@ -8321,7 +8321,7 @@ new_module_tests = [
     ),
     dict(
         module_name='LPPool2d',
-        constructor_args=(2, (2, 2), 2),
+        constructor_args=(2, 2, 2),
         input_size=(1, 3, 7, 7),
     ),
     dict(
@@ -9005,7 +9005,7 @@ new_module_tests = [
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2, 1),
+        constructor_args=(2., 1.),
         input_size=(),
         check_inplace=True,
         desc='threshold_value_scalar'
index a45751f..77843a1 100644 (file)
@@ -519,7 +519,7 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
 
 @torch._jit_internal.weak_script
 def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
-    # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
+    # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
     r"""Applies a 2D power-average pooling over an input signal composed of
     several input planes. If the sum of all inputs to the power of `p` is
     zero, the gradient is set to zero as well.
index 86fcf62..83ee942 100644 (file)
@@ -9,8 +9,6 @@ from ..._jit_internal import weak_module, weak_script_method
 
 @torch._jit_internal.weak_module
 class Threshold(Module):
-    __constants__ = ['threshold', 'value', 'inplace']
-
     r"""Thresholds each element of the input Tensor
 
     Threshold is defined as:
@@ -38,6 +36,7 @@ class Threshold(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['threshold', 'value', 'inplace']
 
     def __init__(self, threshold, value, inplace=False):
         super(Threshold, self).__init__()
@@ -57,6 +56,7 @@ class Threshold(Module):
         )
 
 
+@torch._jit_internal.weak_module
 class ReLU(Threshold):
     r"""Applies the rectified linear unit function element-wise
     :math:`\text{ReLU}(x)= \max(0, x)`
@@ -79,7 +79,7 @@ class ReLU(Threshold):
     """
 
     def __init__(self, inplace=False):
-        super(ReLU, self).__init__(0, 0, inplace)
+        super(ReLU, self).__init__(0., 0., inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
@@ -143,6 +143,7 @@ class RReLU(Module):
         return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class Hardtanh(Module):
     r"""Applies the HardTanh function element-wise
 
@@ -179,8 +180,9 @@ class Hardtanh(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['min_val', 'max_val', 'inplace']
 
-    def __init__(self, min_val=-1, max_val=1, inplace=False, min_value=None, max_value=None):
+    def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
         super(Hardtanh, self).__init__()
         if min_value is not None:
             warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
@@ -194,6 +196,7 @@ class Hardtanh(Module):
         self.inplace = inplace
         assert self.max_val > self.min_val
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
 
@@ -204,6 +207,7 @@ class Hardtanh(Module):
         )
 
 
+@torch._jit_internal.weak_module
 class ReLU6(Hardtanh):
     r"""Applies the element-wise function:
 
@@ -228,7 +232,7 @@ class ReLU6(Hardtanh):
     """
 
     def __init__(self, inplace=False):
-        super(ReLU6, self).__init__(0, 6, inplace)
+        super(ReLU6, self).__init__(0., 6., inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
@@ -288,6 +292,7 @@ class Tanh(Module):
         return torch.tanh(input)
 
 
+@torch._jit_internal.weak_module
 class ELU(Module):
     r"""Applies the element-wise function:
 
@@ -311,12 +316,14 @@ class ELU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['alpha', 'inplace']
 
     def __init__(self, alpha=1., inplace=False):
         super(ELU, self).__init__()
         self.alpha = alpha
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.elu(input, self.alpha, self.inplace)
 
@@ -325,6 +332,7 @@ class ELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class CELU(Module):
     r"""Applies the element-wise function:
 
@@ -353,12 +361,14 @@ class CELU(Module):
     .. _`Continuously Differentiable Exponential Linear Units`:
         https://arxiv.org/abs/1704.07483
     """
+    __constants__ = ['alpha', 'inplace']
 
     def __init__(self, alpha=1., inplace=False):
         super(CELU, self).__init__()
         self.alpha = alpha
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.celu(input, self.alpha, self.inplace)
 
@@ -367,6 +377,7 @@ class CELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class SELU(Module):
     r"""Applied element-wise, as:
 
@@ -396,11 +407,13 @@ class SELU(Module):
 
     .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
     """
+    __constants__ = ['inplace']
 
     def __init__(self, inplace=False):
         super(SELU, self).__init__()
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.selu(input, self.inplace)
 
@@ -409,6 +422,7 @@ class SELU(Module):
         return inplace_str
 
 
+@torch._jit_internal.weak_module
 class GLU(Module):
     r"""Applies the gated linear unit function
     :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@@ -428,11 +442,13 @@ class GLU(Module):
         >>> input = torch.randn(4, 2)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=-1):
         super(GLU, self).__init__()
         self.dim = dim
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.glu(input, self.dim)
 
@@ -482,6 +498,7 @@ class Hardshrink(Module):
         return '{}'.format(self.lambd)
 
 
+@torch._jit_internal.weak_module
 class LeakyReLU(Module):
     r"""Applies the element-wise function:
 
@@ -515,12 +532,14 @@ class LeakyReLU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['inplace', 'negative_slope']
 
     def __init__(self, negative_slope=1e-2, inplace=False):
         super(LeakyReLU, self).__init__()
         self.negative_slope = negative_slope
         self.inplace = inplace
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.leaky_relu(input, self.negative_slope, self.inplace)
 
@@ -529,6 +548,7 @@ class LeakyReLU(Module):
         return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
 
 
+@torch._jit_internal.weak_module
 class LogSigmoid(Module):
     r"""Applies the element-wise function:
 
@@ -548,10 +568,12 @@ class LogSigmoid(Module):
         >>> output = m(input)
     """
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.logsigmoid(input)
 
 
+@torch._jit_internal.weak_module
 class Softplus(Module):
     r"""Applies the element-wise function:
 
@@ -581,12 +603,14 @@ class Softplus(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
+    __constants__ = ['beta', 'threshold']
 
     def __init__(self, beta=1, threshold=20):
         super(Softplus, self).__init__()
         self.beta = beta
         self.threshold = threshold
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softplus(input, self.beta, self.threshold)
 
@@ -753,6 +777,7 @@ class Tanhshrink(Module):
         return F.tanhshrink(input)
 
 
+@torch._jit_internal.weak_module
 class Softmin(Module):
     r"""Applies the Softmin function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -779,15 +804,18 @@ class Softmin(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(Softmin, self).__init__()
         self.dim = dim
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softmin(input, self.dim, _stacklevel=5)
 
 
+@torch._jit_internal.weak_module
 class Softmax(Module):
     r"""Applies the Softmax function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -821,6 +849,7 @@ class Softmax(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(Softmax, self).__init__()
@@ -831,10 +860,12 @@ class Softmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softmax(input, self.dim, _stacklevel=5)
 
 
+@torch._jit_internal.weak_module
 class Softmax2d(Module):
     r"""Applies SoftMax over features to each spatial location.
 
@@ -857,11 +888,13 @@ class Softmax2d(Module):
         >>> output = m(input)
     """
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
         return F.softmax(input, 1, _stacklevel=5)
 
 
+@torch._jit_internal.weak_module
 class LogSoftmax(Module):
     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
     input Tensor. The LogSoftmax formulation can be simplified as:
@@ -887,6 +920,7 @@ class LogSoftmax(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
+    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(LogSoftmax, self).__init__()
@@ -897,5 +931,6 @@ class LogSoftmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
+    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.log_softmax(input, self.dim, _stacklevel=5)