From: David Riazati
Date: Fri, 30 Nov 2018 06:16:52 +0000 (-0800)
Subject: Add InstanceNorm, Distance modules to Script
X-Git-Tag: submit/tizen/20210715.075526~2588
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1f6d9f44fc42ca5ce0734206be7b3b963f41fb3e;p=platform%2Fupstream%2Fpytorch.git

Add InstanceNorm, Distance modules to Script

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14551

Differential Revision: D13272741

Pulled By: driazati

fbshipit-source-id: 3e4fe870d0e268903757f3ae8a56100606906bce
---

diff --git a/test/test_nn.py b/test/test_nn.py
index 0746d68477..25f8ee7070 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8354,7 +8354,7 @@ new_module_tests = [
     ),
     dict(
         module_name='LocalResponseNorm',
-        constructor_args=(1, 1, 0.5, 2),
+        constructor_args=(1, 1., 0.5, 2.),
         input_size=(1, 5, 7, 7, 7),
         desc='3d_custom_params',
     ),
@@ -8391,17 +8391,17 @@ new_module_tests = [
     ),
     dict(
         module_name='ConstantPad1d',
-        constructor_args=((1, 2), 2),
+        constructor_args=((1, 2), 2.),
         input_size=(2, 3, 4)
     ),
     dict(
         module_name='ConstantPad2d',
-        constructor_args=((1, 2, 3, 4), 2),
+        constructor_args=((1, 2, 3, 4), 2.),
         input_size=(2, 3, 4, 4)
     ),
     dict(
         module_name='ConstantPad3d',
-        constructor_args=((1, 2, 3, 4, 1, 0), 2),
+        constructor_args=((1, 2, 3, 4, 1, 0), 2.),
         input_size=(2, 3, 4, 4, 5)
     ),
     dict(
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 83ee94265b..a7224a355b 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -7,7 +7,7 @@ from .. import functional as F
 from ..._jit_internal import weak_module, weak_script_method


-@torch._jit_internal.weak_module
+@weak_module
 class Threshold(Module):
     r"""Thresholds each element of the input Tensor

@@ -45,7 +45,7 @@ class Threshold(Module):
         self.inplace = inplace
         # TODO: check in THNN (if inplace == True, then assert value <= threshold)

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.threshold(input, self.threshold, self.value, self.inplace)

@@ -56,7 +56,7 @@ class Threshold(Module):
         )


-@torch._jit_internal.weak_module
+@weak_module
 class ReLU(Threshold):
     r"""Applies the rectified linear unit function element-wise
     :math:`\text{ReLU}(x)= \max(0, x)`
@@ -86,7 +86,7 @@ class ReLU(Threshold):
         return inplace_str


-@torch._jit_internal.weak_module
+@weak_module
 class RReLU(Module):
     r"""Applies the randomized leaky rectified liner unit function,
     element-wise, as described in the paper:
@@ -134,7 +134,7 @@ class RReLU(Module):
         self.upper = upper
         self.inplace = inplace

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)

@@ -143,7 +143,7 @@ class RReLU(Module):
         return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str)


-@torch._jit_internal.weak_module
+@weak_module
 class Hardtanh(Module):
     r"""Applies the HardTanh function element-wise

@@ -196,7 +196,7 @@ class Hardtanh(Module):
         self.inplace = inplace
         assert self.max_val > self.min_val

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.hardtanh(input, self.min_val, self.max_val, self.inplace)

@@ -207,7 +207,7 @@ class Hardtanh(Module):
         )


-@torch._jit_internal.weak_module
+@weak_module
 class ReLU6(Hardtanh):
     r"""Applies the element-wise function:

@@ -239,7 +239,7 @@ class ReLU6(Hardtanh):
         return inplace_str


-@torch._jit_internal.weak_module
+@weak_module
 class Sigmoid(Module):
     r"""Applies the element-wise function:

@@ -261,12 +261,12 @@ class Sigmoid(Module):
     >>> output = m(input)
     """

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return torch.sigmoid(input)


-@torch._jit_internal.weak_module
+@weak_module
 class Tanh(Module):
     r"""Applies the element-wise function:

@@ -287,12 +287,12 @@ class Tanh(Module):
     >>> output = m(input)
     """

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return torch.tanh(input)


-@torch._jit_internal.weak_module
+@weak_module
 class ELU(Module):
     r"""Applies the element-wise function:

@@ -323,7 +323,7 @@ class ELU(Module):
         self.alpha = alpha
         self.inplace = inplace

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.elu(input, self.alpha, self.inplace)

@@ -332,7 +332,7 @@ class ELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)


-@torch._jit_internal.weak_module
+@weak_module
 class CELU(Module):
     r"""Applies the element-wise function:

@@ -368,7 +368,7 @@ class CELU(Module):
         self.alpha = alpha
         self.inplace = inplace

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.celu(input, self.alpha, self.inplace)

@@ -377,7 +377,7 @@ class CELU(Module):
         return 'alpha={}{}'.format(self.alpha, inplace_str)


-@torch._jit_internal.weak_module
+@weak_module
 class SELU(Module):
     r"""Applied element-wise, as:

@@ -413,7 +413,7 @@ class SELU(Module):
         super(SELU, self).__init__()
         self.inplace = inplace

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.selu(input, self.inplace)

@@ -422,7 +422,7 @@ class SELU(Module):
         return inplace_str


-@torch._jit_internal.weak_module
+@weak_module
 class GLU(Module):
     r"""Applies the gated linear unit function
     :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
@@ -448,7 +448,7 @@ class GLU(Module):
         super(GLU, self).__init__()
         self.dim = dim

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.glu(input, self.dim)

@@ -456,7 +456,7 @@ class GLU(Module):
         return 'dim={}'.format(self.dim)


-@torch._jit_internal.weak_module
+@weak_module
 class Hardshrink(Module):
     r"""Applies the hard shrinkage function element-wise:

@@ -490,7 +490,7 @@ class Hardshrink(Module):
         super(Hardshrink, self).__init__()
         self.lambd = lambd

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.hardshrink(input, self.lambd)

@@ -498,7 +498,7 @@ class Hardshrink(Module):
         return '{}'.format(self.lambd)


-@torch._jit_internal.weak_module
+@weak_module
 class LeakyReLU(Module):
     r"""Applies the element-wise function:

@@ -539,7 +539,7 @@ class LeakyReLU(Module):
         self.negative_slope = negative_slope
         self.inplace = inplace

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.leaky_relu(input, self.negative_slope, self.inplace)

@@ -548,7 +548,7 @@ class LeakyReLU(Module):
         return 'negative_slope={}{}'.format(self.negative_slope, inplace_str)


-@torch._jit_internal.weak_module
+@weak_module
 class LogSigmoid(Module):
     r"""Applies the element-wise function:

@@ -568,12 +568,12 @@ class LogSigmoid(Module):
     >>> output = m(input)
     """

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.logsigmoid(input)


-@torch._jit_internal.weak_module
+@weak_module
 class Softplus(Module):
     r"""Applies the element-wise function:

@@ -610,7 +610,7 @@ class Softplus(Module):
         self.beta = beta
         self.threshold = threshold

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softplus(input, self.beta, self.threshold)

@@ -618,7 +618,7 @@ class Softplus(Module):
         return 'beta={}, threshold={}'.format(self.beta, self.threshold)


-@torch._jit_internal.weak_module
+@weak_module
 class Softshrink(Module):
     r"""Applies the soft shrinkage function elementwise:

@@ -652,7 +652,7 @@ class Softshrink(Module):
         super(Softshrink, self).__init__()
         self.lambd = lambd

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softshrink(input, self.lambd)

@@ -660,7 +660,7 @@ class Softshrink(Module):
         return str(self.lambd)


-@torch._jit_internal.weak_module
+@weak_module
 class PReLU(Module):
     r"""Applies the element-wise function:

@@ -717,7 +717,7 @@ class PReLU(Module):
         super(PReLU, self).__init__()
         self.weight = Parameter(torch.Tensor(num_parameters).fill_(init))

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.prelu(input, self.weight)

@@ -725,7 +725,7 @@ class PReLU(Module):
         return 'num_parameters={}'.format(self.num_parameters)


-@torch._jit_internal.weak_module
+@weak_module
 class Softsign(Module):
     r"""Applies the element-wise function:

@@ -746,12 +746,12 @@ class Softsign(Module):
     >>> output = m(input)
     """

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softsign(input)


-@torch._jit_internal.weak_module
+@weak_module
 class Tanhshrink(Module):
     r"""Applies the element-wise function:

@@ -772,12 +772,12 @@ class Tanhshrink(Module):
     >>> output = m(input)
     """

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.tanhshrink(input)


-@torch._jit_internal.weak_module
+@weak_module
 class Softmin(Module):
     r"""Applies the Softmin function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -810,12 +810,12 @@ class Softmin(Module):
         super(Softmin, self).__init__()
         self.dim = dim

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softmin(input, self.dim, _stacklevel=5)


-@torch._jit_internal.weak_module
+@weak_module
 class Softmax(Module):
     r"""Applies the Softmax function to an n-dimensional input Tensor
     rescaling them so that the elements of the n-dimensional output Tensor
@@ -860,12 +860,12 @@ class Softmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.softmax(input, self.dim, _stacklevel=5)


-@torch._jit_internal.weak_module
+@weak_module
 class Softmax2d(Module):
     r"""Applies SoftMax over features to each spatial location.

@@ -888,13 +888,13 @@ class Softmax2d(Module):
     >>> output = m(input)
     """

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
         return F.softmax(input, 1, _stacklevel=5)


-@torch._jit_internal.weak_module
+@weak_module
 class LogSoftmax(Module):
     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
     input Tensor. The LogSoftmax formulation can be simplified as:
@@ -931,6 +931,6 @@ class LogSoftmax(Module):
         if not hasattr(self, 'dim'):
             self.dim = None

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.log_softmax(input, self.dim, _stacklevel=5)
diff --git a/torch/nn/modules/distance.py b/torch/nn/modules/distance.py
index 2ce9c1337f..45e4ae7280 100644
--- a/torch/nn/modules/distance.py
+++ b/torch/nn/modules/distance.py
@@ -4,7 +4,7 @@ from .. import functional as F
 from ..._jit_internal import weak_module, weak_script_method


-@torch._jit_internal.weak_module
+@weak_module
 class PairwiseDistance(Module):
     r"""
     Computes the batchwise pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
@@ -39,11 +39,12 @@ class PairwiseDistance(Module):
         self.eps = eps
         self.keepdim = keepdim

-    @torch._jit_internal.weak_script_method
+    @weak_script_method
     def forward(self, x1, x2):
         return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)


+@weak_module
 class CosineSimilarity(Module):
     r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim.

@@ -67,10 +68,13 @@ class CosineSimilarity(Module):
     >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
     >>> output = cos(input1, input2)
     """
+    __constants__ = ['dim', 'eps']
+
     def __init__(self, dim=1, eps=1e-8):
         super(CosineSimilarity, self).__init__()
         self.dim = dim
         self.eps = eps

+    @weak_script_method
     def forward(self, x1, x2):
         return F.cosine_similarity(x1, x2, self.dim, self.eps)
diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py
index 829e1ce032..d5b8427a36 100644
--- a/torch/nn/modules/instancenorm.py
+++ b/torch/nn/modules/instancenorm.py
@@ -1,13 +1,18 @@
 from .batchnorm import _BatchNorm
 from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method


 class _InstanceNorm(_BatchNorm):
+    __constants__ = ['running_mean', 'running_var', 'weight', 'bias',
+                     'training', 'track_running_stats', 'momentum', 'eps']
+
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False,
                  track_running_stats=False):
         super(_InstanceNorm, self).__init__(
             num_features, eps, momentum, affine, track_running_stats)

+    @weak_script_method
     def _check_input_dim(self, input):
         raise NotImplementedError

@@ -41,6 +46,7 @@ class _InstanceNorm(_BatchNorm):
             state_dict, prefix, local_metadata, strict,
             missing_keys, unexpected_keys, error_msgs)

+    @weak_script_method
     def forward(self, input):
         self._check_input_dim(input)

@@ -49,6 +55,7 @@ class _InstanceNorm(_BatchNorm):
             self.training or not self.track_running_stats, self.momentum, self.eps)


+@weak_module
 class InstanceNorm1d(_InstanceNorm):
     r"""Applies Instance Normalization over a 2D or 3D input (a mini-batch of 1D
     inputs with optional additional channel dimension) as described in the paper
@@ -117,12 +124,14 @@ class InstanceNorm1d(_InstanceNorm):
         https://arxiv.org/abs/1607.08022
     """

+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 2 and input.dim() != 3:
             raise ValueError('expected 2D or 3D input (got {}D input)'
                              .format(input.dim()))


+@weak_module
 class InstanceNorm2d(_InstanceNorm):
     r"""Applies Instance Normalization over a 4D input (a mini-batch of 2D inputs
     with additional channel dimension) as described in the paper
@@ -191,12 +200,14 @@ class InstanceNorm2d(_InstanceNorm):
         https://arxiv.org/abs/1607.08022
     """

+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 4:
             raise ValueError('expected 4D input (got {}D input)'
                              .format(input.dim()))


+@weak_module
 class InstanceNorm3d(_InstanceNorm):
     r"""Applies Instance Normalization over a 5D input (a mini-batch of 3D inputs
     with additional channel dimension) as described in the paper
@@ -265,6 +276,7 @@ class InstanceNorm3d(_InstanceNorm):
         https://arxiv.org/abs/1607.08022
     """

+    @weak_script_method
     def _check_input_dim(self, input):
         if input.dim() != 5:
             raise ValueError('expected 5D input (got {}D input)'
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 3cd4661e5a..3ed374f7b8 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -105,6 +105,7 @@ class Bilinear(Module):
         >>> print(output.size())
         torch.Size([128, 40])
     """
+    __constants__ = ['in1_features', 'in2_features', 'out_features']

     def __init__(self, in1_features, in2_features, out_features, bias=True):
         super(Bilinear, self).__init__()
diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py
index b3bf05f9d5..b7308214ac 100644
--- a/torch/nn/modules/normalization.py
+++ b/torch/nn/modules/normalization.py
@@ -5,8 +5,10 @@ from .module import Module
 from .batchnorm import _BatchNorm
 from .. import functional as F
 from .. import init
+from ..._jit_internal import weak_module, weak_script_method


+@weak_module
 class LocalResponseNorm(Module):
     r"""Applies local response normalization over an input signal composed
     of several input planes, where channels occupy the second dimension.
@@ -35,14 +37,16 @@ class LocalResponseNorm(Module):
         >>> output_4d = lrn(signal_4d)

     """
+    __constants__ = ['size', 'alpha', 'beta', 'k']

-    def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
+    def __init__(self, size, alpha=1e-4, beta=0.75, k=1.):
         super(LocalResponseNorm, self).__init__()
         self.size = size
         self.alpha = alpha
         self.beta = beta
         self.k = k

+    @weak_script_method
     def forward(self, input):
         return F.local_response_norm(input, self.size, self.alpha, self.beta,
                                      self.k)
@@ -68,6 +72,7 @@ class CrossMapLRN2d(Module):
         return '{size}, alpha={alpha}, beta={beta}, k={k}'.format(**self.__dict__)


+@weak_module
 class LayerNorm(Module):
     r"""Applies Layer Normalization over a mini-batch of inputs as described in
     the paper `Layer Normalization`_ .
@@ -145,6 +150,7 @@ class LayerNorm(Module):
             init.ones_(self.weight)
             init.zeros_(self.bias)

+    @weak_script_method
     def forward(self, input):
         return F.layer_norm(
             input, self.normalized_shape, self.weight, self.bias, self.eps)
diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py
index b9d63a084d..9ff8d59cb0 100644
--- a/torch/nn/modules/padding.py
+++ b/torch/nn/modules/padding.py
@@ -1,17 +1,21 @@
 from .module import Module
 from .utils import _pair, _quadruple, _ntuple
 from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method


 # TODO: grad_output size asserts in THNN


+@weak_module
 class _ConstantPadNd(Module):
+    __constants__ = ['padding', 'value']

     def __init__(self, value):
         super(_ConstantPadNd, self).__init__()
         self.value = value

+    @weak_script_method
     def forward(self, input):
         return F.pad(input, self.padding, 'constant', self.value)

@@ -19,6 +23,7 @@ class _ConstantPadNd(Module):
         return 'padding={}, value={}'.format(self.padding, self.value)


+@weak_module
 class ConstantPad1d(_ConstantPadNd):
     r"""Pads the input tensor boundaries with a constant value.

@@ -67,6 +72,7 @@ class ConstantPad1d(_ConstantPadNd):
         self.padding = _pair(padding)


+@weak_module
 class ConstantPad2d(_ConstantPadNd):
     r"""Pads the input tensor boundaries with a constant value.

@@ -114,12 +120,14 @@ class ConstantPad2d(_ConstantPadNd):
                   [ 3.5000,  3.5000,  3.5000,  3.5000,  3.5000]]])

     """
+    __constants__ = ['padding', 'value']

     def __init__(self, padding, value):
         super(ConstantPad2d, self).__init__(value)
         self.padding = _quadruple(padding)


+@weak_module
 class ConstantPad3d(_ConstantPadNd):
     r"""Pads the input tensor boundaries with a constant value.

@@ -155,8 +163,11 @@ class ConstantPad3d(_ConstantPadNd):
         self.padding = _ntuple(6)(padding)


+@weak_module
 class _ReflectionPadNd(Module):
+    __constants__ = ['padding']

+    @weak_script_method
     def forward(self, input):
         return F.pad(input, self.padding, 'reflect')

@@ -164,6 +175,7 @@ class _ReflectionPadNd(Module):
         return '{}'.format(self.padding)


+@weak_module
 class ReflectionPad1d(_ReflectionPadNd):
     r"""Pads the input tensor using the reflection of the input boundary.

@@ -205,6 +217,7 @@ class ReflectionPad1d(_ReflectionPadNd):
         self.padding = _pair(padding)


+@weak_module
 class ReflectionPad2d(_ReflectionPadNd):
     r"""Pads the input tensor using the reflection of the input boundary.

@@ -254,8 +267,11 @@ class ReflectionPad2d(_ReflectionPadNd):
         self.padding = _quadruple(padding)


+@weak_module
 class _ReplicationPadNd(Module):
+    __constants__ = ['padding']

+    @weak_script_method
     def forward(self, input):
         return F.pad(input, self.padding, 'replicate')

@@ -263,6 +279,7 @@ class _ReplicationPadNd(Module):
         return '{}'.format(self.padding)


+@weak_module
 class ReplicationPad1d(_ReplicationPadNd):
     r"""Pads the input tensor using replication of the input boundary.

@@ -301,6 +318,7 @@ class ReplicationPad1d(_ReplicationPadNd):
         self.padding = _pair(padding)


+@weak_module
 class ReplicationPad2d(_ReplicationPadNd):
     r"""Pads the input tensor using replication of the input boundary.

@@ -349,6 +367,7 @@ class ReplicationPad2d(_ReplicationPadNd):
         self.padding = _quadruple(padding)


+@weak_module
 class ReplicationPad3d(_ReplicationPadNd):
     r"""Pads the input tensor using replication of the input boundary.

@@ -384,6 +403,7 @@ class ReplicationPad3d(_ReplicationPadNd):
         self.padding = _ntuple(6)(padding)


+@weak_module
 class ZeroPad2d(ConstantPad2d):
     r"""Pads the input tensor boundaries with zero.

@@ -428,4 +448,4 @@ class ZeroPad2d(ConstantPad2d):
     """

     def __init__(self, padding):
-        super(ZeroPad2d, self).__init__(padding, 0)
+        super(ZeroPad2d, self).__init__(padding, 0.)
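For context, once these modules are marked as weak modules they can be assigned to and called from a torch.jit.ScriptModule and compiled by the JIT. The sketch below illustrates the intended usage under that assumption; the class name PairFeatures, the shapes, and the channel count are illustrative and not part of this patch.

    import torch
    import torch.nn as nn


    class PairFeatures(torch.jit.ScriptModule):
        def __init__(self):
            super(PairFeatures, self).__init__()
            # InstanceNorm1d and CosineSimilarity are weak modules after this
            # change, so the JIT can compile them when they are called from a
            # script method.
            self.norm = nn.InstanceNorm1d(4)
            self.cos = nn.CosineSimilarity(dim=1)

        @torch.jit.script_method
        def forward(self, x1, x2):
            # (N, C, L) inputs; cosine similarity is taken over the channel dim.
            return self.cos(self.norm(x1), self.norm(x2))


    m = PairFeatures()
    out = m(torch.randn(2, 4, 8), torch.randn(2, 4, 8))  # shape (2, 8)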