From 3d98810fbd3c8068e4dbbd3a686805a12c861f8c Mon Sep 17 00:00:00 2001
From: David Riazati
Date: Wed, 28 Nov 2018 00:21:01 -0800
Subject: [PATCH] Revert D13192230: [pytorch][PR] [jit] Use nn module tests in test_jit

Differential Revision: D13192230

Original commit changeset: 36488960b6c9

fbshipit-source-id: 63b68bd909b9ef0548f52c986c84f549aecb8909
---
 test/common_nn.py                       |  4 +-
 test/expect/TestJit.test_alexnet.expect | 65 ++++++++++++++++-----------------
 test/test_jit.py                        | 65 +++++++--------------------------
 test/test_nn.py                         |  4 +-
 torch/nn/functional.py                  |  2 +-
 torch/nn/modules/activation.py          | 45 +++--------------------
 6 files changed, 56 insertions(+), 129 deletions(-)

diff --git a/test/common_nn.py b/test/common_nn.py
index 27f1571..c56ee4d 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -51,14 +51,14 @@ module_tests = [
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2., 1.),
+        constructor_args=(2, 1),
         input_size=(2, 3, 4, 5),
         check_inplace=True,
         desc='threshold_value'
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2., 10.),
+        constructor_args=(2, 10),
         input_size=(2, 3, 4, 5),
         desc='large_value'
     ),
diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect
index 0652078..bd9e335 100644
--- a/test/expect/TestJit.test_alexnet.expect
+++ b/test/expect/TestJit.test_alexnet.expect
@@ -26,37 +26,36 @@ graph(%0 : Double(1, 3, 224, 224)
   %25 : int[] = prim::ListConstruct(%24, %24), scope: AlexNet/Sequential[features]/Conv2d[0]
   %26 : bool = prim::Constant[value=1](), scope: AlexNet/Sequential[features]/Conv2d[0]
   %input.1 : Double(1, 64, 55, 55) = aten::_convolution(%0, %1, %2, %18, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[0]
-  %28 : float = prim::Constant[value=0](), scope: AlexNet/Sequential[features]/ReLU[1]
-  %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %28, %28), scope: AlexNet/Sequential[features]/ReLU[1]
-  %30 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %31 : int[] = prim::ListConstruct(%30, %30), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %32 : Double(1, 64, 27, 27), %33 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
-  %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%32, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
-  %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %28, %28), scope: AlexNet/Sequential[features]/ReLU[4]
-  %36 : Double(1, 192, 13, 13), %37 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
-  %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%36, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
-  %39 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %28, %28), scope: AlexNet/Sequential[features]/ReLU[7]
-  %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%39, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
-  %41 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %28, %28), scope: AlexNet/Sequential[features]/ReLU[9]
-  %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%41, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
-  %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %28, %28), scope: AlexNet/Sequential[features]/ReLU[11]
-  %44 : Double(1, 256, 6, 6), %45 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %31, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
-  %46 : int = aten::size(%44, %24), scope: AlexNet
-  %47 : Long() = prim::NumToTensor(%46), scope: AlexNet
-  %48 : int = prim::TensorToNum(%47), scope: AlexNet
-  %49 : int = prim::Constant[value=9216](), scope: AlexNet
-  %50 : int[] = prim::ListConstruct(%48, %49), scope: AlexNet
-  %input.9 : Double(1, 9216) = aten::view(%44, %50), scope: AlexNet
-  %52 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %input.10 : Double(1, 9216) = aten::dropout(%input.9, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
-  %54 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %54, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
-  %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[2]
-  %input.13 : Double(1, 4096) = aten::dropout(%input.12, %52, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
-  %58 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %58, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
-  %input : Double(1, 4096) = aten::threshold_(%input.14, %28, %28), scope: AlexNet/Sequential[classifier]/ReLU[5]
-  %61 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
-  %62 : Double(1, 1000) = aten::addmm(%16, %input, %61, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
-  return (%62);
+  %input.2 : Double(1, 64, 55, 55) = aten::threshold_(%input.1, %24, %24), scope: AlexNet/Sequential[features]/ReLU[1]
+  %29 : int = prim::Constant[value=3](), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %30 : int[] = prim::ListConstruct(%29, %29), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %31 : Double(1, 64, 27, 27), %32 : Long(1, 64, 27, 27) = aten::max_pool2d_with_indices(%input.2, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[2]
+  %input.3 : Double(1, 192, 27, 27) = aten::_convolution(%31, %3, %4, %22, %20, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[3]
+  %input.4 : Double(1, 192, 27, 27) = aten::threshold_(%input.3, %24, %24), scope: AlexNet/Sequential[features]/ReLU[4]
+  %35 : Double(1, 192, 13, 13), %36 : Long(1, 192, 13, 13) = aten::max_pool2d_with_indices(%input.4, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[5]
+  %input.5 : Double(1, 384, 13, 13) = aten::_convolution(%35, %5, %6, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[6]
+  %38 : Double(1, 384, 13, 13) = aten::threshold_(%input.5, %24, %24), scope: AlexNet/Sequential[features]/ReLU[7]
+  %input.6 : Double(1, 256, 13, 13) = aten::_convolution(%38, %7, %8, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[8]
+  %40 : Double(1, 256, 13, 13) = aten::threshold_(%input.6, %24, %24), scope: AlexNet/Sequential[features]/ReLU[9]
+  %input.7 : Double(1, 256, 13, 13) = aten::_convolution(%40, %9, %10, %22, %22, %22, %23, %25, %21, %23, %23, %26), scope: AlexNet/Sequential[features]/Conv2d[10]
+  %input.8 : Double(1, 256, 13, 13) = aten::threshold_(%input.7, %24, %24), scope: AlexNet/Sequential[features]/ReLU[11]
+  %43 : Double(1, 256, 6, 6), %44 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices(%input.8, %30, %20, %25, %22, %23), scope: AlexNet/Sequential[features]/MaxPool2d[12]
+  %45 : int = aten::size(%43, %24), scope: AlexNet
+  %46 : Long() = prim::NumToTensor(%45), scope: AlexNet
+  %47 : int = prim::TensorToNum(%46), scope: AlexNet
+  %48 : int = prim::Constant[value=9216](), scope: AlexNet
+  %49 : int[] = prim::ListConstruct(%47, %48), scope: AlexNet
+  %input.9 : Double(1, 9216) = aten::view(%43, %49), scope: AlexNet
+  %51 : float = prim::Constant[value=0.5](), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %input.10 : Double(1, 9216) = aten::dropout(%input.9, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[0]
+  %53 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %input.11 : Double(1, 4096) = aten::addmm(%12, %input.10, %53, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[1]
+  %input.12 : Double(1, 4096) = aten::threshold_(%input.11, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[2]
+  %input.13 : Double(1, 4096) = aten::dropout(%input.12, %51, %26), scope: AlexNet/Sequential[classifier]/Dropout[3]
+  %57 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %input.14 : Double(1, 4096) = aten::addmm(%14, %input.13, %57, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[4]
+  %input : Double(1, 4096) = aten::threshold_(%input.14, %24, %24), scope: AlexNet/Sequential[classifier]/ReLU[5]
+  %60 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6]
+  %61 : Double(1, 1000) = aten::addmm(%16, %input, %60, %21, %21), scope: AlexNet/Sequential[classifier]/Linear[6]
+  return (%61);
 }
diff --git a/test/test_jit.py b/test/test_jit.py
index 0ea0c49..540a13e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -11,10 +11,8 @@ from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2
-from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
-    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
-    freeze_rng_state
-from test_nn import module_tests, new_module_tests
+from common_utils import (TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN,
+                          skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE)
 from textwrap import dedent
 import os
 import io
@@ -43,7 +41,6 @@ from torch.jit import BatchTensor
 from test_module.future_div import div_int_future, div_float_future
 from test_module.no_future_div import div_int_nofuture, div_float_nofuture
 
-
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
 
@@ -449,8 +446,9 @@ class JitTestCase(TestCase):
 
     def runAndSaveRNG(self, func, inputs, kwargs=None):
         kwargs = kwargs if kwargs else {}
-        with freeze_rng_state():
-            results = func(*inputs, **kwargs)
+        initial_rng_state = torch.get_rng_state()
+        results = func(*inputs, **kwargs)
+        torch.set_rng_state(initial_rng_state)
         return results
 
 
@@ -9398,11 +9396,6 @@ EXCLUDE_SCRIPT = {
     'test_nn_max_unpool1d',
 }
 
-EXCLUDE_SCRIPT_MODULES = {
-    'test_nn_LPPool2d_norm',
-    'test_nn_LPPool1d_norm',
-}
-
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
     'test_nn_avg_pool2d',
     'test_nn_log_softmax',
@@ -10206,37 +10199,10 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
     post_add_test(test_name, skipTestIf, do_test)
 
 
-def add_nn_module_test(*args, **kwargs):
-    if 'module_name' in kwargs:
-        name = kwargs['module_name']
-    elif 'fullname' in kwargs:
-        name = kwargs['fullname']
-    elif 'constructor' in kwargs:
-        name = kwargs['constructor'].__name__
-
-    class_name = name.split("_")[0]
-
-    module = getattr(torch.nn, class_name, None)
-    if module is None or torch._jit_internal._weak_types.get(module) is None:
-        return
-
-    if 'desc' in kwargs and 'eval' in kwargs['desc']:
-        # eval() is not supported, so skip these tests
-        return
-
-    test_name = name
-    if 'desc' in kwargs:
-        test_name = "{}_{}".format(test_name, kwargs['desc'])
-    test_name = 'test_nn_{}'.format(test_name)
-
+def add_nn_module_test(module_name, constructor_args, call_args,
+                       use_as_constant=False, skipTestIf=()):
     def do_test(self):
-        if test_name in EXCLUDE_SCRIPT_MODULES:
-            return
-        if 'constructor' in kwargs:
-            nn_module = kwargs['constructor']
-        else:
-            nn_module = getattr(torch.nn, name)
-        constructor_args = kwargs.get('constructor_args', ())
+        nn_module = getattr(torch.nn, module_name)
 
         # Construct a script module that passes arguments through
         # to self.submodule
@@ -10249,7 +10215,7 @@ def add_nn_module_test(*args, **kwargs):
         script = script_method_template.format(method_args, call)
 
         submodule_constants = []
-        if kwargs.get('is_constant'):
+        if use_as_constant:
             submodule_constants = ['submodule']
 
         # Create module to use the script method
@@ -10275,16 +10241,13 @@ def add_nn_module_test(*args, **kwargs):
             return module(*args)
 
         # Check against Python module as reference
-        if 'input_fn' in kwargs:
-            input_size = tuple(kwargs['input_fn']().size())
-        else:
-            input_size = kwargs['input_size']
-        args_variable, kwargs_variable = create_input((input_size,))
+        args_variable, kwargs_variable = create_input(call_args)
         f_args_variable = deepcopy(unpack_variables(args_variable))
         check_against_reference(self, create_script_module, create_nn_module, f_args_variable)
 
-    post_add_test(test_name, (), do_test)
+    test_name = 'test_nn_{}'.format(module_name)
+    post_add_test(test_name, skipTestIf, do_test)
 
 
 def post_add_test(test_name, skipTestIf, do_test):
@@ -10448,8 +10411,8 @@ for test in autograd_method_tests:
 for test in nn_functional_tests:
     add_nn_functional_test(*test)
 
-for test in module_tests + new_module_tests:
-    add_nn_module_test(**test)
+for test in nn_module_tests:
+    add_nn_module_test(*test)
 
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_nn.py b/test/test_nn.py
index f6aa527..de61e68 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8321,7 +8321,7 @@ new_module_tests = [
     ),
     dict(
         module_name='LPPool2d',
-        constructor_args=(2, 2, 2),
+        constructor_args=(2, (2, 2), 2),
         input_size=(1, 3, 7, 7),
     ),
     dict(
@@ -9005,7 +9005,7 @@ new_module_tests = [
     ),
     dict(
         module_name='Threshold',
-        constructor_args=(2., 1.),
+        constructor_args=(2, 1),
         input_size=(),
         check_inplace=True,
         desc='threshold_value_scalar'
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 77843a1..a45751f 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -519,7 +519,7 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
 
 @torch._jit_internal.weak_script
 def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
-    # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
+    # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
     r"""Applies a 2D power-average pooling over an input signal composed of
     several input planes. If the sum of all inputs to the power of `p` is
     zero, the gradient is set to zero as well.
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 83ee942..86fcf62 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -9,6 +9,8 @@ from ..._jit_internal import weak_module, weak_script_method
 
 @torch._jit_internal.weak_module
 class Threshold(Module):
+    __constants__ = ['threshold', 'value', 'inplace']
+
     r"""Thresholds each element of the input Tensor
 
     Threshold is defined as:
@@ -36,7 +38,6 @@ class Threshold(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
-    __constants__ = ['threshold', 'value', 'inplace']
 
     def __init__(self, threshold, value, inplace=False):
         super(Threshold, self).__init__()
@@ -56,7 +57,6 @@ class Threshold(Module):
         )
 
 
-@torch._jit_internal.weak_module
 class ReLU(Threshold):
     r"""Applies the rectified linear unit function element-wise
     :math:`\text{ReLU}(x)= \max(0, x)`
@@ -79,7 +79,7 @@ class ReLU(Threshold):
     """
 
     def __init__(self, inplace=False):
-        super(ReLU, self).__init__(0., 0., inplace)
+        super(ReLU, self).__init__(0, 0, inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
@@ -143,7 +143,6 @@ class RReLU(Module):
         return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
 
 
-@torch._jit_internal.weak_module
 class Hardtanh(Module):
     r"""Applies the HardTanh function element-wise
 
@@ -180,9 +179,8 @@ class Hardtanh(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
-    __constants__ = ['min_val', 'max_val', 'inplace']
 
-    def __init__(self, min_val=-1., max_val=1., inplace=False, min_value=None, max_value=None):
+    def __init__(self, min_val=-1, max_val=1, inplace=False, min_value=None, max_value=None):
         super(Hardtanh, self).__init__()
         if min_value is not None:
             warnings.warn("keyword argument min_value is deprecated and renamed to min_val")
@@ -196,7 +194,6 @@ class Hardtanh(Module):
         self.inplace = inplace
         assert self.max_val > self.min_val
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
 
@@ -207,7 +204,6 @@ class Hardtanh(Module):
         )
 
 
-@torch._jit_internal.weak_module
 class ReLU6(Hardtanh):
     r"""Applies the element-wise function:
 
@@ -232,7 +228,7 @@ class ReLU6(Hardtanh):
     """
 
     def __init__(self, inplace=False):
-        super(ReLU6, self).__init__(0., 6., inplace)
+        super(ReLU6, self).__init__(0, 6, inplace)
 
     def extra_repr(self):
         inplace_str = 'inplace' if self.inplace else ''
@@ -292,7 +288,6 @@ class Tanh(Module):
         return torch.tanh(input)
 
 
-@torch._jit_internal.weak_module
 class ELU(Module):
     r"""Applies the element-wise function:
 
@@ -316,14 +311,12 @@ class ELU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
-    __constants__ = ['alpha', 'inplace']
 
     def __init__(self, alpha=1., inplace=False):
         super(ELU, self).__init__()
         self.alpha = alpha
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.elu(input, self.alpha, self.inplace)
 
@@ -332,7 +325,6 @@ class ELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
-@torch._jit_internal.weak_module
 class CELU(Module):
     r"""Applies the element-wise function:
 
@@ -361,14 +353,12 @@ class CELU(Module):
     .. _`Continuously Differentiable Exponential Linear Units`:
         https://arxiv.org/abs/1704.07483
     """
-    __constants__ = ['alpha', 'inplace']
 
    def __init__(self, alpha=1., inplace=False):
         super(CELU, self).__init__()
         self.alpha = alpha
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.celu(input, self.alpha, self.inplace)
 
@@ -377,7 +367,6 @@ class CELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
-@torch._jit_internal.weak_module
 class SELU(Module):
     r"""Applied element-wise, as:
 
@@ -407,13 +396,11 @@ class SELU(Module):
     .. _Self-Normalizing Neural Networks:
         https://arxiv.org/abs/1706.02515
     """
-    __constants__ = ['inplace']
 
     def __init__(self, inplace=False):
         super(SELU, self).__init__()
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.selu(input, self.inplace)
 
@@ -422,7 +409,6 @@ class SELU(Module):
         return inplace_str
 
 
-@torch._jit_internal.weak_module
 class GLU(Module):
     r"""Applies the gated linear unit function
     :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@@ -442,13 +428,11 @@ class GLU(Module):
         >>> input = torch.randn(4, 2)
         >>> output = m(input)
     """
-    __constants__ = ['dim']
 
     def __init__(self, dim=-1):
         super(GLU, self).__init__()
         self.dim = dim
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.glu(input, self.dim)
 
@@ -498,7 +482,6 @@ class Hardshrink(Module):
         return '{}'.format(self.lambd)
 
 
-@torch._jit_internal.weak_module
 class LeakyReLU(Module):
     r"""Applies the element-wise function:
 
@@ -532,14 +515,12 @@ class LeakyReLU(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
-    __constants__ = ['inplace', 'negative_slope']
 
     def __init__(self, negative_slope=1e-2, inplace=False):
         super(LeakyReLU, self).__init__()
         self.negative_slope = negative_slope
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.leaky_relu(input, self.negative_slope, self.inplace)
 
@@ -548,7 +529,6 @@ class LeakyReLU(Module):
         return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
 
 
-@torch._jit_internal.weak_module
 class LogSigmoid(Module):
     r"""Applies the element-wise function:
 
@@ -568,12 +548,10 @@ class LogSigmoid(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.logsigmoid(input)
 
 
-@torch._jit_internal.weak_module
 class Softplus(Module):
     r"""Applies the element-wise function:
 
@@ -603,14 +581,12 @@ class Softplus(Module):
         >>> input = torch.randn(2)
         >>> output = m(input)
     """
-    __constants__ = ['beta', 'threshold']
 
     def __init__(self, beta=1, threshold=20):
         super(Softplus, self).__init__()
         self.beta = beta
         self.threshold = threshold
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softplus(input, self.beta, self.threshold)
 
@@ -777,7 +753,6 @@ class Tanhshrink(Module):
         return F.tanhshrink(input)
 
 
-@torch._jit_internal.weak_module
 class Softmin(Module):
     r"""Applies the Softmin function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -804,18 +779,15 @@ class Softmin(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
-    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(Softmin, self).__init__()
         self.dim = dim
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softmin(input, self.dim, _stacklevel=5)
 
 
-@torch._jit_internal.weak_module
 class Softmax(Module):
     r"""Applies the Softmax function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -849,7 +821,6 @@ class Softmax(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
-    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(Softmax, self).__init__()
@@ -860,12 +831,10 @@ class Softmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.softmax(input, self.dim, _stacklevel=5)
 
 
-@torch._jit_internal.weak_module
 class Softmax2d(Module):
     r"""Applies SoftMax over features to each spatial location.
 
@@ -888,13 +857,11 @@ class Softmax2d(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
         return F.softmax(input, 1, _stacklevel=5)
 
 
-@torch._jit_internal.weak_module
 class LogSoftmax(Module):
     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
     input Tensor. The LogSoftmax formulation can be simplified as:
@@ -920,7 +887,6 @@ class LogSoftmax(Module):
         >>> input = torch.randn(2, 3)
         >>> output = m(input)
     """
-    __constants__ = ['dim']
 
     def __init__(self, dim=None):
         super(LogSoftmax, self).__init__()
@@ -931,6 +897,5 @@ class LogSoftmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
-    @torch._jit_internal.weak_script_method
     def forward(self, input):
         return F.log_softmax(input, self.dim, _stacklevel=5)
-- 
2.7.4
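
Note (not part of the patch): the test_jit.py hunk above swaps the `freeze_rng_state` helper for an explicit save/restore of the RNG state inside `runAndSaveRNG`. For readers unfamiliar with the helper, here is a minimal sketch of an equivalent context manager; it assumes only the global CPU RNG needs preserving (the real helper in common_utils may also handle CUDA state) and is an illustration rather than the actual implementation being reverted.

    import contextlib
    import torch

    @contextlib.contextmanager
    def freeze_rng_state():
        # Save the CPU RNG state, run the wrapped code, then restore the state
        # so the caller's random stream is unaffected even if the body raises.
        state = torch.get_rng_state()
        try:
            yield
        finally:
            torch.set_rng_state(state)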