Add InstanceNorm, Distance modules to Script
author     David Riazati <davidriazati@fb.com>
           Fri, 30 Nov 2018 06:16:52 +0000 (22:16 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Fri, 30 Nov 2018 06:18:55 +0000 (22:18 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14551

Differential Revision: D13272741

Pulled By: driazati

fbshipit-source-id: 3e4fe870d0e268903757f3ae8a56100606906bce

test/test_nn.py
torch/nn/modules/activation.py
torch/nn/modules/distance.py
torch/nn/modules/instancenorm.py
torch/nn/modules/linear.py
torch/nn/modules/normalization.py
torch/nn/modules/padding.py

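Taken together, the diffs below do four things: the distance, instance-norm, normalization and padding modules gain @weak_module on the class and @weak_script_method on forward; CosineSimilarity, _InstanceNorm, Bilinear, LocalResponseNorm and the padding bases gain __constants__ lists so the compiler can treat those attributes as compile-time constants; activation.py swaps its fully qualified decorators for the aliases it already imports; and the tests pass explicit floats where the script signatures expect float. A minimal usage sketch of what this enables (not part of the diff; it assumes the ScriptModule / script_method frontend of this era):

    import torch
    import torch.nn as nn

    # Sketch only: instance-normalize two image batches, then compare them
    # channel-wise with CosineSimilarity, all inside a compiled forward.
    class Model(torch.jit.ScriptModule):
        def __init__(self):
            super(Model, self).__init__()
            self.norm = nn.InstanceNorm2d(3)
            self.cos = nn.CosineSimilarity(dim=1)

        @torch.jit.script_method
        def forward(self, a, b):
            return self.cos(self.norm(a), self.norm(b))

    m = Model()
    out = m(torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8))  # shape (2, 8, 8)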
test/test_nn.py
index 0746d68..25f8ee7 100644
@@ -8354,7 +8354,7 @@ new_module_tests = [
     ),
     dict(
         module_name='LocalResponseNorm',
-        constructor_args=(1, 1, 0.5, 2),
+        constructor_args=(1, 1., 0.5, 2.),
         input_size=(1, 5, 7, 7, 7),
         desc='3d_custom_params',
     ),
@@ -8391,17 +8391,17 @@ new_module_tests = [
     ),
     dict(
         module_name='ConstantPad1d',
-        constructor_args=((1, 2), 2),
+        constructor_args=((1, 2), 2.),
         input_size=(2, 3, 4)
     ),
     dict(
         module_name='ConstantPad2d',
-        constructor_args=((1, 2, 3, 4), 2),
+        constructor_args=((1, 2, 3, 4), 2.),
         input_size=(2, 3, 4, 4)
     ),
     dict(
         module_name='ConstantPad3d',
-        constructor_args=((1, 2, 3, 4, 1, 0), 2),
+        constructor_args=((1, 2, 3, 4, 1, 0), 2.),
         input_size=(2, 3, 4, 4, 5)
     ),
     dict(
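The hunks above only touch the test fixtures: constructor arguments are written as explicit floats (1., 2.) so they line up with the float-typed parameters in the weak-script signatures (the LRN alpha/k and the constant-pad fill value). Eager behaviour is identical; for illustration only, the updated arguments used directly:

    import torch
    import torch.nn as nn

    pad = nn.ConstantPad2d((1, 2, 3, 4), 2.)
    lrn = nn.LocalResponseNorm(1, 1., 0.5, 2.)

    x = torch.randn(2, 3, 4, 4)
    print(pad(x).shape)   # torch.Size([2, 3, 11, 7])
    print(lrn(x).shape)   # torch.Size([2, 3, 4, 4])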
torch/nn/modules/activation.py
index 83ee942..a7224a3 100644
@@ -7,7 +7,7 @@ from .. import functional as F
 from ..._jit_internal import weak_module, weak_script_method
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Threshold(Module):
     r"""Thresholds each element of the input Tensor
 
@@ -45,7 +45,7 @@ class Threshold(Module):
         self.inplace = inplace
         # TODO: check in THNN (if inplace == True, then assert value <= threshold)
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.threshold(input, self.threshold, self.value, self.inplace)
 
@@ -56,7 +56,7 @@ class Threshold(Module):
         )
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class ReLU(Threshold):
     r"""Applies the rectified linear unit function element-wise
     :math:`\text{ReLU}(x)= \max(0, x)`
@@ -86,7 +86,7 @@ class ReLU(Threshold):
         return inplace_str
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class RReLU(Module):
     r"""Applies the randomized leaky rectified liner unit function, element-wise,
     as described in the paper:
@@ -134,7 +134,7 @@ class RReLU(Module):
         self.upper = upper
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
 
@@ -143,7 +143,7 @@ class RReLU(Module):
         return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Hardtanh(Module):
     r"""Applies the HardTanh function element-wise
 
@@ -196,7 +196,7 @@ class Hardtanh(Module):
         self.inplace = inplace
         assert self.max_val > self.min_val
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
 
@@ -207,7 +207,7 @@ class Hardtanh(Module):
         )
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class ReLU6(Hardtanh):
     r"""Applies the element-wise function:
 
@@ -239,7 +239,7 @@ class ReLU6(Hardtanh):
         return inplace_str
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Sigmoid(Module):
     r"""Applies the element-wise function:
 
@@ -261,12 +261,12 @@ class Sigmoid(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return torch.sigmoid(input)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Tanh(Module):
     r"""Applies the element-wise function:
 
@@ -287,12 +287,12 @@ class Tanh(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return torch.tanh(input)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class ELU(Module):
     r"""Applies the element-wise function:
 
@@ -323,7 +323,7 @@ class ELU(Module):
         self.alpha = alpha
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.elu(input, self.alpha, self.inplace)
 
@@ -332,7 +332,7 @@ class ELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class CELU(Module):
     r"""Applies the element-wise function:
 
@@ -368,7 +368,7 @@ class CELU(Module):
         self.alpha = alpha
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.celu(input, self.alpha, self.inplace)
 
@@ -377,7 +377,7 @@ class CELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class SELU(Module):
     r"""Applied element-wise, as:
 
@@ -413,7 +413,7 @@ class SELU(Module):
         super(SELU, self).__init__()
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.selu(input, self.inplace)
 
@@ -422,7 +422,7 @@ class SELU(Module):
         return inplace_str
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class GLU(Module):
     r"""Applies the gated linear unit function
     :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@@ -448,7 +448,7 @@ class GLU(Module):
         super(GLU, self).__init__()
         self.dim = dim
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.glu(input, self.dim)
 
@@ -456,7 +456,7 @@ class GLU(Module):
         return 'dim={}'.format(self.dim)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Hardshrink(Module):
     r"""Applies the hard shrinkage function element-wise:
 
@@ -490,7 +490,7 @@ class Hardshrink(Module):
         super(Hardshrink, self).__init__()
         self.lambd = lambd
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.hardshrink(input, self.lambd)
 
@@ -498,7 +498,7 @@ class Hardshrink(Module):
         return '{}'.format(self.lambd)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class LeakyReLU(Module):
     r"""Applies the element-wise function:
 
@@ -539,7 +539,7 @@ class LeakyReLU(Module):
         self.negative_slope = negative_slope
         self.inplace = inplace
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.leaky_relu(input, self.negative_slope, self.inplace)
 
@@ -548,7 +548,7 @@ class LeakyReLU(Module):
         return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class LogSigmoid(Module):
     r"""Applies the element-wise function:
 
@@ -568,12 +568,12 @@ class LogSigmoid(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.logsigmoid(input)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Softplus(Module):
     r"""Applies the element-wise function:
 
@@ -610,7 +610,7 @@ class Softplus(Module):
         self.beta = beta
         self.threshold = threshold
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softplus(input, self.beta, self.threshold)
 
@@ -618,7 +618,7 @@ class Softplus(Module):
         return 'beta={}, threshold={}'.format(self.beta, self.threshold)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Softshrink(Module):
     r"""Applies the soft shrinkage function elementwise:
 
@@ -652,7 +652,7 @@ class Softshrink(Module):
         super(Softshrink, self).__init__()
         self.lambd = lambd
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softshrink(input, self.lambd)
 
@@ -660,7 +660,7 @@ class Softshrink(Module):
         return str(self.lambd)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class PReLU(Module):
     r"""Applies the element-wise function:
 
@@ -717,7 +717,7 @@ class PReLU(Module):
         super(PReLU, self).__init__()
         self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.prelu(input, self.weight)
 
@@ -725,7 +725,7 @@ class PReLU(Module):
         return 'num_parameters={}'.format(self.num_parameters)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Softsign(Module):
     r"""Applies the element-wise function:
 
@@ -746,12 +746,12 @@ class Softsign(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softsign(input)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Tanhshrink(Module):
     r"""Applies the element-wise function:
 
@@ -772,12 +772,12 @@ class Tanhshrink(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.tanhshrink(input)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Softmin(Module):
     r"""Applies the Softmin function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -810,12 +810,12 @@ class Softmin(Module):
         super(Softmin, self).__init__()
         self.dim = dim
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softmin(input, self.dim, _stacklevel=5)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Softmax(Module):
     r"""Applies the Softmax function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -860,12 +860,12 @@ class Softmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softmax(input, self.dim, _stacklevel=5)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class Softmax2d(Module):
     r"""Applies SoftMax over features to each spatial location.
 
@@ -888,13 +888,13 @@ class Softmax2d(Module):
         >>> output = m(input)
     """
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
         return F.softmax(input, 1, _stacklevel=5)
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class LogSoftmax(Module):
     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
     input Tensor. The LogSoftmax formulation can be simplified as:
@@ -931,6 +931,6 @@ class LogSoftmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.log_softmax(input, self.dim, _stacklevel=5)
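Everything in activation.py above is mechanical: the fully qualified torch._jit_internal.weak_module / weak_script_method decorators are swapped for the aliases already imported at the top of the file, with no behaviour change. A quick eager spot check (not from the diff):

    import torch
    import torch.nn as nn

    act = nn.LeakyReLU(negative_slope=0.1)
    print(act(torch.tensor([-1.0, 0.0, 2.0])))   # -> tensor([-0.1000,  0.0000,  2.0000])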
torch/nn/modules/distance.py
index 2ce9c13..45e4ae7 100644
@@ -4,7 +4,7 @@ from .. import functional as F
 from ..._jit_internal import weak_module, weak_script_method
 
 
-@torch._jit_internal.weak_module
+@weak_module
 class PairwiseDistance(Module):
     r"""
     Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
@@ -39,11 +39,12 @@ class PairwiseDistance(Module):
         self.eps = eps
         self.keepdim = keepdim
 
-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, x1, x2):
         return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
 
 
+@weak_module
 class CosineSimilarity(Module):
     r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.
 
@@ -67,10 +68,13 @@ class CosineSimilarity(Module):
         >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
         >>> output = cos(input1, input2)
     """
+    __constants__ = ['dim', 'eps']
+
     def __init__(self, dim=1, eps=1e-8):
         super(CosineSimilarity, self).__init__()
         self.dim = dim
         self.eps = eps
 
+    @weak_script_method
     def forward(self, x1, x2):
         return F.cosine_similarity(x1, x2, self.dim, self.eps)
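CosineSimilarity gets the same treatment PairwiseDistance already had, plus a __constants__ list so that dim and eps are baked in at compile time instead of being looked up as Python attributes. The same mechanism on a toy module, to make it concrete (hypothetical example, not from this PR):

    import torch

    class Scale(torch.jit.ScriptModule):
        __constants__ = ['factor']

        def __init__(self, factor):
            super(Scale, self).__init__()
            self.factor = factor

        @torch.jit.script_method
        def forward(self, x):
            # `factor` is inlined as a compile-time constant here.
            return x * self.factor

    print(Scale(2.0)(torch.ones(3)))   # tensor([2., 2., 2.])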
torch/nn/modules/instancenorm.py
index 829e1ce..d5b8427 100644
@@ -1,13 +1,18 @@
 from .batchnorm import _BatchNorm
 from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
 
 
 class _InstanceNorm(_BatchNorm):
+    __constants__ = ['running_mean', 'running_var', 'weight', 'bias',
+                     'training', 'track_running_stats', 'momentum', 'eps']
+
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
                  track_running_stats=False):
         super(_InstanceNorm, self).__init__(
             num_features, eps, momentum, affine, track_running_stats)
 
+    @weak_script_method
     def _check_input_dim(self, input):
         raise NotImplementedError
 
@@ -41,6 +46,7 @@ class _InstanceNorm(_BatchNorm):
             state_dict, prefix, local_metadata, strict,
             missing_keys, unexpected_keys, error_msgs)
 
+    @weak_script_method
     def forward(self, input):
         self._check_input_dim(input)
 
@@ -49,6 +55,7 @@ class _InstanceNorm(_BatchNorm):
             self.training or not self.track_running_stats, self.momentum, self.eps)
 
 
+@weak_module
 class InstanceNorm1d(_InstanceNorm):
     r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D
     inputs with optional additional channel dimension) as described in the paper
@@ -117,12 +124,14 @@ class InstanceNorm1d(_InstanceNorm):
         https://arxiv.org/abs/1607.08022
     """
 
+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError('expected 2D or 3D input (got {}D input)'
                              .format(input.dim()))
 
 
+@weak_module
 class InstanceNorm2d(_InstanceNorm):
     r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
     with additional channel dimension) as described in the paper
@@ -191,12 +200,14 @@ class InstanceNorm2d(_InstanceNorm):
         https://arxiv.org/abs/1607.08022
     """
 
+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 4:
             raise ValueError('expected 4D input (got {}D input)'
                              .format(input.dim()))
 
 
+@weak_module
 class InstanceNorm3d(_InstanceNorm):
     r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
     with additional channel dimension) as described in the paper
@@ -265,6 +276,7 @@ class InstanceNorm3d(_InstanceNorm):
         https://arxiv.org/abs/1607.08022
     """
 
+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
torch/nn/modules/linear.py
index 3cd4661..3ed374f 100644
@@ -105,6 +105,7 @@ class Bilinear(Module):
         >>> print(output.size())
         torch.Size([128, 40])
     """
+    __constants__ = ['in1_features', 'in2_features', 'out_features']
 
     def __init__(self, in1_features, in2_features, out_features, bias=True):
         super(Bilinear, self).__init__()
torch/nn/modules/normalization.py
index b3bf05f..b730821 100644
@@ -5,8 +5,10 @@ from .module import Module
 from .batchnorm import _BatchNorm
 from .. import functional as F
 from .. import init
+from ..._jit_internal import weak_module, weak_script_method
 
 
+@weak_module
 class LocalResponseNorm(Module):
     r"""Applies local response normalization over an input signal composed
     of several input planes, where channels occupy the second dimension.
@@ -35,14 +37,16 @@ class LocalResponseNorm(Module):
         >>> output_4d = lrn(signal_4d)
 
     """
+    __constants__ = ['size', 'alpha', 'beta', 'k']
 
-    def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
+    def __init__(self, size, alpha=1e-4, beta=0.75, k=1.):
         super(LocalResponseNorm, self).__init__()
         self.size = size
         self.alpha = alpha
         self.beta = beta
         self.k = k
 
+    @weak_script_method
     def forward(self, input):
         return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                      self.k)
@@ -68,6 +72,7 @@ class CrossMapLRN2d(Module):
         return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)
 
 
+@weak_module
 class LayerNorm(Module):
     r"""Applies Layer Normalization over a mini-batch of inputs as described in
     the paper `Layer Normalization`_ .
@@ -145,6 +150,7 @@ class LayerNorm(Module):
             init.ones_(self.weight)
             init.zeros_(self.bias)
 
+    @weak_script_method
     def forward(self, input):
         return F.layer_norm(
             input, self.normalized_shape, self.weight, self.bias, self.eps)
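LocalResponseNorm and LayerNorm get the full treatment (decorator, weak-script forward, and for LRN a __constants__ list), and the default k becomes a float for the same typing reason as the test changes above. Used as a submodule of a scripted module (sketch, same era frontend assumed):

    import torch
    import torch.nn as nn

    class Block(torch.jit.ScriptModule):
        def __init__(self):
            super(Block, self).__init__()
            self.lrn = nn.LocalResponseNorm(2)

        @torch.jit.script_method
        def forward(self, x):
            return self.lrn(x)

    print(Block()(torch.randn(1, 4, 8, 8)).shape)   # torch.Size([1, 4, 8, 8])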
torch/nn/modules/padding.py
index b9d63a0..9ff8d59 100644
@@ -1,17 +1,21 @@
 from .module import Module
 from .utils import _pair, _quadruple, _ntuple
 from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
 
 
 # TODO: grad_output size asserts in THNN
 
 
+@weak_module
 class _ConstantPadNd(Module):
+    __constants__ = ['padding', 'value']
 
     def __init__(self, value):
         super(_ConstantPadNd, self).__init__()
         self.value = value
 
+    @weak_script_method
     def forward(self, input):
         return F.pad(input, self.padding, 'constant', self.value)
 
@@ -19,6 +23,7 @@ class _ConstantPadNd(Module):
         return 'padding={}, value={}'.format(self.padding, self.value)
 
 
+@weak_module
 class ConstantPad1d(_ConstantPadNd):
     r"""Pads the input tensor boundaries with a constant value.
 
@@ -67,6 +72,7 @@ class ConstantPad1d(_ConstantPadNd):
         self.padding = _pair(padding)
 
 
+@weak_module
 class ConstantPad2d(_ConstantPadNd):
     r"""Pads the input tensor boundaries with a constant value.
 
@@ -114,12 +120,14 @@ class ConstantPad2d(_ConstantPadNd):
                  [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000]]])
 
     """
+    __constants__ = ['padding', 'value']
 
     def __init__(self, padding, value):
         super(ConstantPad2d, self).__init__(value)
         self.padding = _quadruple(padding)
 
 
+@weak_module
 class ConstantPad3d(_ConstantPadNd):
     r"""Pads the input tensor boundaries with a constant value.
 
@@ -155,8 +163,11 @@ class ConstantPad3d(_ConstantPadNd):
         self.padding = _ntuple(6)(padding)
 
 
+@weak_module
 class _ReflectionPadNd(Module):
+    __constants__ = ['padding']
 
+    @weak_script_method
     def forward(self, input):
         return F.pad(input, self.padding, 'reflect')
 
@@ -164,6 +175,7 @@ class _ReflectionPadNd(Module):
         return '{}'.format(self.padding)
 
 
+@weak_module
 class ReflectionPad1d(_ReflectionPadNd):
     r"""Pads the input tensor using the reflection of the input boundary.
 
@@ -205,6 +217,7 @@ class ReflectionPad1d(_ReflectionPadNd):
         self.padding = _pair(padding)
 
 
+@weak_module
 class ReflectionPad2d(_ReflectionPadNd):
     r"""Pads the input tensor using the reflection of the input boundary.
 
@@ -254,8 +267,11 @@ class ReflectionPad2d(_ReflectionPadNd):
         self.padding = _quadruple(padding)
 
 
+@weak_module
 class _ReplicationPadNd(Module):
+    __constants__ = ['padding']
 
+    @weak_script_method
     def forward(self, input):
         return F.pad(input, self.padding, 'replicate')
 
@@ -263,6 +279,7 @@ class _ReplicationPadNd(Module):
         return '{}'.format(self.padding)
 
 
+@weak_module
 class ReplicationPad1d(_ReplicationPadNd):
     r"""Pads the input tensor using replication of the input boundary.
 
@@ -301,6 +318,7 @@ class ReplicationPad1d(_ReplicationPadNd):
         self.padding = _pair(padding)
 
 
+@weak_module
 class ReplicationPad2d(_ReplicationPadNd):
     r"""Pads the input tensor using replication of the input boundary.
 
@@ -349,6 +367,7 @@ class ReplicationPad2d(_ReplicationPadNd):
         self.padding = _quadruple(padding)
 
 
+@weak_module
 class ReplicationPad3d(_ReplicationPadNd):
     r"""Pads the input tensor using replication of the input boundary.
 
@@ -384,6 +403,7 @@ class ReplicationPad3d(_ReplicationPadNd):
         self.padding = _ntuple(6)(padding)
 
 
+@weak_module
 class ZeroPad2d(ConstantPad2d):
     r"""Pads the input tensor boundaries with zero.
 
@@ -428,4 +448,4 @@ class ZeroPad2d(ConstantPad2d):
     """
 
     def __init__(self, padding):
-        super(ZeroPad2d, self).__init__(padding, 0)
+        super(ZeroPad2d, self).__init__(padding, 0.)
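The padding modules close things out: the shared bases (_ConstantPadNd, _ReflectionPadNd, _ReplicationPadNd) carry the __constants__ lists and weak-script forwards (ConstantPad2d restates its own), the concrete classes gain @weak_module, and ZeroPad2d now passes a float fill value (0.) to its ConstantPad2d base. One last eager sanity check (illustrative):

    import torch
    import torch.nn as nn

    pad = nn.ZeroPad2d(2)
    out = pad(torch.ones(1, 1, 3, 3))
    print(out.shape)          # torch.Size([1, 1, 7, 7])
    print(out[0, 0, 0, 0])    # tensor(0.)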