Remove fully qualified weak script names (#15364)
author David Riazati <davidriazati@fb.com>
Wed, 19 Dec 2018 00:44:04 +0000 (16:44 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 00:48:52 +0000 (16:48 -0800)
Summary:
Cleanup to make references to `weak_script` consistent across the codebase
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15364

Differential Revision: D13509676

Pulled By: driazati

fbshipit-source-id: 93dbbbe57e9b9b6587895f3cc6fac678babd21de
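
The change swaps the fully qualified `@torch._jit_internal.weak_script` decorator for a bare `@weak_script` imported once at module scope (inside `torch.nn` the import is written relatively, as `from ..._jit_internal import weak_script`). A minimal sketch of that spelling follows; it assumes the PyTorch-1.0-era `torch._jit_internal.weak_script` API shown in this diff, and `scaled_relu` is a hypothetical example function, not part of the change:

    # Import the decorator once per module instead of qualifying it at each use site.
    import torch
    from torch._jit_internal import weak_script  # previously: @torch._jit_internal.weak_script

    @weak_script
    def scaled_relu(input, scale=2.0):
        # type: (Tensor, float) -> Tensor
        # The type comment is what the JIT reads when it compiles this weak script.
        return torch.relu(input) * scale

Importing the name once keeps every call site short and makes `@weak_script` uses easy to grep for.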

torch/nn/_functions/vision.py
torch/nn/functional.py

index cccf011..159025e 100644
--- a/torch/nn/_functions/vision.py
+++ b/torch/nn/_functions/vision.py
@@ -1,8 +1,9 @@
 import torch
 import torch.backends.cudnn as cudnn
+from ..._jit_internal import weak_script
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def affine_grid_generator(theta, size):
     # type: (Tensor, List[int]) -> Tensor
     if theta.is_cuda and cudnn.enabled and cudnn.is_acceptable(theta) and len(size) == 4:
index 72c5c27..f56ca3c 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -300,7 +300,7 @@ Args:
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
                                        output_ratio=None, return_indices=False,
                                        _random_samples=None):
@@ -351,7 +351,7 @@ def fractional_max_pool2d_with_indices(input, kernel_size, output_size=None,
     return torch._C._nn.fractional_max_pool2d(input, kernel_size, _output_size, _random_samples)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _fractional_max_pool2d(input, kernel_size, output_size=None,
                            output_ratio=None, return_indices=False,
                            _random_samples=None):
@@ -368,7 +368,7 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
     if_false=_fractional_max_pool2d)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
                             dilation=1, ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor]  # noqa
@@ -385,7 +385,7 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
         input, kernel_size, _stride, padding, dilation, ceil_mode)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor  # noqa
@@ -400,7 +400,7 @@ max_pool1d = torch._jit_internal.boolean_dispatch(
     if_false=_max_pool1d)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
                             ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor]  # noqa
@@ -416,7 +416,7 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
     return torch._C._nn.max_pool2d_with_indices(input, kernel_size, _stride, padding, dilation, ceil_mode)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor  # noqa
@@ -431,7 +431,7 @@ max_pool2d = torch._jit_internal.boolean_dispatch(
     if_false=_max_pool2d)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
                             dilation=1, ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor]  # noqa
@@ -448,7 +448,7 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
         input, kernel_size, _stride, padding, dilation, ceil_mode)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
     # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor  # noqa
@@ -463,7 +463,7 @@ max_pool3d = torch._jit_internal.boolean_dispatch(
     if_false=_max_pool3d)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _unpool_output_size(input, kernel_size, stride, padding, output_size):
     # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
     input_size = input.size()
@@ -494,7 +494,7 @@ def _unpool_output_size(input, kernel_size, stride, padding, output_size):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
                  output_size=None):
     # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor  # noqa
@@ -513,7 +513,7 @@ def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
     return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3), output_size + [1]).squeeze(3)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
                  output_size=None):
     # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor  # noqa
@@ -532,7 +532,7 @@ def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
     return torch._C._nn.max_unpool2d(input, indices, output_size)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
                  output_size=None):
     # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor  # noqa
@@ -552,7 +552,7 @@ def max_unpool3d(input, indices, kernel_size, stride=None, padding=0,
         input, indices, output_size, _stride, padding)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
     # type: (Tensor, float, int, Optional[BroadcastingList2[int]], bool) -> Tensor
     r"""Applies a 2D power-average pooling over an input signal composed of
@@ -571,7 +571,7 @@ def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
     return (torch.sign(out) * relu(torch.abs(out))).mul(kw * kh).pow(1. / norm_type)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
     # type: (Tensor, float, int, Optional[BroadcastingList1[int]], bool) -> Tensor
     r"""Applies a 1D power-average pooling over an input signal composed of
@@ -589,7 +589,7 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
     return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
     r"""Applies a 1D adaptive max pooling over an input signal composed of
@@ -604,7 +604,7 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
     return torch.adaptive_max_pool1d(input, output_size)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _adaptive_max_pool1d(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
     return adaptive_max_pool1d_with_indices(input, output_size)[0]
@@ -617,7 +617,7 @@ adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
     if_false=_adaptive_max_pool1d)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
     r"""Applies a 2D adaptive max pooling over an input signal composed of
@@ -634,7 +634,7 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
     return torch._C._nn.adaptive_max_pool2d(input, output_size)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _adaptive_max_pool2d(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
     return adaptive_max_pool2d_with_indices(input, output_size)[0]
@@ -647,7 +647,7 @@ adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
     if_false=_adaptive_max_pool2d)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
     r"""Applies a 3D adaptive max pooling over an input signal composed of
@@ -664,7 +664,7 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
     return torch._C._nn.adaptive_max_pool3d(input, output_size)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _adaptive_max_pool3d(input, output_size, return_indices=False):
     # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
     return adaptive_max_pool3d_with_indices(input, output_size)[0]
@@ -690,7 +690,7 @@ Args:
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def adaptive_avg_pool2d(input, output_size):
     # type: (Tensor, BroadcastingList2[int]) -> Tensor
     r"""
@@ -707,7 +707,7 @@ def adaptive_avg_pool2d(input, output_size):
     return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def adaptive_avg_pool3d(input, output_size):
     # type: (Tensor, BroadcastingList3[int]) -> Tensor
     r"""
@@ -725,7 +725,7 @@ def adaptive_avg_pool3d(input, output_size):
 
 
 # Activation functions
-@torch._jit_internal.weak_script
+@weak_script
 def dropout(input, p=0.5, training=True, inplace=False):
     # type: (Tensor, float, bool, bool) -> Tensor
     r"""
@@ -748,7 +748,7 @@ def dropout(input, p=0.5, training=True, inplace=False):
             else _VF.dropout(input, p, training))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def alpha_dropout(input, p=0.5, training=False, inplace=False):
     # type: (Tensor, float, bool, bool) -> Tensor
     r"""Applies alpha dropout to the input.
@@ -763,7 +763,7 @@ def alpha_dropout(input, p=0.5, training=False, inplace=False):
             else _VF.alpha_dropout(input, p, training))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def dropout2d(input, p=0.5, training=True, inplace=False):
     # type: (Tensor, float, bool, bool) -> Tensor
     r"""
@@ -788,7 +788,7 @@ def dropout2d(input, p=0.5, training=True, inplace=False):
             else _VF.feature_dropout(input, p, training))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def dropout3d(input, p=0.5, training=True, inplace=False):
     # type: (Tensor, float, bool, bool) -> Tensor
     r"""
@@ -815,7 +815,7 @@ def dropout3d(input, p=0.5, training=True, inplace=False):
             else _VF.feature_dropout(input, p, training))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
     # type: (Tensor, float, bool, bool) -> Tensor
     if p < 0. or p > 1.:
@@ -826,7 +826,7 @@ def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
             else _VF.feature_alpha_dropout(input, p, training))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def threshold(input, threshold, value, inplace=False):
     # type: (Tensor, float, float, bool) -> Tensor
     r"""Thresholds each element of the input Tensor.
@@ -847,7 +847,7 @@ In-place version of :func:`~threshold`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def relu(input, inplace=False):
     # type: (Tensor, bool) -> Tensor
     r"""relu(input, inplace=False) -> Tensor
@@ -869,7 +869,7 @@ In-place version of :func:`~relu`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def glu(input, dim=-1):
     # type: (Tensor, int) -> Tensor
     r"""
@@ -894,7 +894,7 @@ def glu(input, dim=-1):
     return torch._C._nn.glu(input, dim)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def hardtanh(input, min_val=-1., max_val=1., inplace=False):
     # type: (Tensor, float, float, bool) -> Tensor
     r"""
@@ -917,7 +917,7 @@ In-place version of :func:`~hardtanh`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def relu6(input, inplace=False):
     # type: (Tensor, bool) -> Tensor
     r"""relu6(input, inplace=False) -> Tensor
@@ -929,7 +929,7 @@ def relu6(input, inplace=False):
     return hardtanh(input, 0., 6., inplace)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def elu(input, alpha=1., inplace=False):
     # type: (Tensor, float, bool) -> Tensor
     r"""Applies element-wise,
@@ -951,7 +951,7 @@ In-place version of :func:`~elu`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def selu(input, inplace=False):
     # type: (Tensor, bool) -> Tensor
     r"""selu(input, inplace=False) -> Tensor
@@ -977,7 +977,7 @@ In-place version of :func:`~selu`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def celu(input, alpha=1., inplace=False):
     # type: (Tensor, float, bool) -> Tensor
     r"""celu(input, alpha=1., inplace=False) -> Tensor
@@ -1000,7 +1000,7 @@ In-place version of :func:`~celu`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def leaky_relu(input, negative_slope=0.01, inplace=False):
     # type: (Tensor, float, bool) -> Tensor
     r"""
@@ -1025,7 +1025,7 @@ In-place version of :func:`~leaky_relu`.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def prelu(input, weight):
     # type: (Tensor, Tensor) -> Tensor
     r"""prelu(input, weight) -> Tensor
@@ -1039,7 +1039,7 @@ def prelu(input, weight):
     return torch.prelu(input, weight)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
     # type: (Tensor, float, float, bool, bool) -> Tensor
     r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor
@@ -1070,7 +1070,7 @@ See :class:`~torch.nn.LogSigmoid` for more details.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def hardshrink(input, lambd=0.5):
     # type: (Tensor, float) -> Tensor
     r"""
@@ -1083,7 +1083,7 @@ def hardshrink(input, lambd=0.5):
     return torch.hardshrink(input, lambd)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def tanhshrink(input):
     r"""tanhshrink(input) -> Tensor
 
@@ -1094,7 +1094,7 @@ def tanhshrink(input):
     return input - input.tanh()
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def softsign(input):
     r"""softsign(input) -> Tensor
 
@@ -1110,7 +1110,7 @@ softplus(input, beta=1, threshold=20) -> Tensor
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _get_softmax_dim(name, ndim, stacklevel):
     # type: (str, int, int) -> int
     warnings.warn("Implicit dimension choice for {} has been deprecated. "
@@ -1122,7 +1122,7 @@ def _get_softmax_dim(name, ndim, stacklevel):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def softmin(input, dim=None, _stacklevel=3, dtype=None):
     # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
     r"""Applies a softmin function.
@@ -1151,7 +1151,7 @@ def softmin(input, dim=None, _stacklevel=3, dtype=None):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def softmax(input, dim=None, _stacklevel=3, dtype=None):
     # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
     r"""Applies a softmax function.
@@ -1191,7 +1191,7 @@ def softmax(input, dim=None, _stacklevel=3, dtype=None):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _sample_gumbel(shape, eps=1e-10, out=None):
     # type: (List[int], float, Optional[Tensor]) -> Tensor
     """
@@ -1208,7 +1208,7 @@ def _sample_gumbel(shape, eps=1e-10, out=None):
     return - torch.log(eps - torch.log(U + eps))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
     # type: (Tensor, float, float) -> Tensor
     """
@@ -1224,7 +1224,7 @@ def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
     return softmax(y / tau, dims - 1)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def gumbel_softmax(logits, tau=1., hard=False, eps=1e-10):
     # type: (Tensor, float, bool, float) -> Tensor
     r"""
@@ -1268,7 +1268,7 @@ def gumbel_softmax(logits, tau=1., hard=False, eps=1e-10):
     return y
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def log_softmax(input, dim=None, _stacklevel=3, dtype=None):
     # type: (Tensor, Optional[int], int, Optional[int]) -> Tensor
     r"""Applies a softmax followed by a logarithm.
@@ -1307,7 +1307,7 @@ See :class:`~torch.nn.Softshrink` for more details.
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def tanh(input):
     r"""tanh(input) -> Tensor
 
@@ -1320,7 +1320,7 @@ def tanh(input):
     return input.tanh()
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def sigmoid(input):
     r"""sigmoid(input) -> Tensor
 
@@ -1332,7 +1332,7 @@ def sigmoid(input):
     return input.sigmoid()
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def linear(input, weight, bias=None):
     # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
     r"""
@@ -1357,7 +1357,7 @@ def linear(input, weight, bias=None):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def bilinear(input1, input2, weight, bias=None):
     # type: (Tensor, Tensor, Tensor, Optional[Tensor]) -> Tensor
     return torch.bilinear(input1, input2, weight, bias)
@@ -1369,7 +1369,7 @@ def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
         return torch.embedding_renorm_(weight, input, max_norm, norm_type)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
               scale_grad_by_freq=False, sparse=False):
     # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
@@ -1453,7 +1453,7 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                   scale_grad_by_freq=False, mode='mean', sparse=False):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
@@ -1591,7 +1591,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def batch_norm(input, running_mean, running_var, weight=None, bias=None,
                training=False, momentum=0.1, eps=1e-5):
     # type: (Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor
@@ -1623,7 +1623,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None,
     )
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def instance_norm(input, running_mean=None, running_var=None, weight=None,
                   bias=None, use_input_stats=True, momentum=0.1, eps=1e-5):
     # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor  # noqa
@@ -1639,7 +1639,7 @@ def instance_norm(input, running_mean=None, running_var=None, weight=None,
     )
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
     # type: (Tensor, List[int], Optional[Tensor], Optional[Tensor], float) -> Tensor
     r"""Applies Layer Normalization for last certain number of dimensions.
@@ -1650,7 +1650,7 @@ def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
                             torch.backends.cudnn.enabled)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
     # type: (Tensor, int, Optional[Tensor], Optional[Tensor], float) -> Tensor
     r"""Applies Group Normalization for last certain number of dimensions.
@@ -1661,7 +1661,7 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
                             torch.backends.cudnn.enabled)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
     # type: (Tensor, int, float, float, float) -> Tensor
     r"""Applies local response normalization over an input signal composed of
@@ -1690,7 +1690,7 @@ def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
 
 # loss
 
-@torch._jit_internal.weak_script
+@weak_script
 def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
              reduction='mean'):
     # type: (Tensor, Tensor, Tensor, Tensor, int, str) -> Tensor
@@ -1731,7 +1731,7 @@ def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0,
     return torch.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, _Reduction.get_enum(reduction))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
              reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
@@ -1810,7 +1810,7 @@ def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-8,
                      reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, bool, bool, Optional[bool], float, Optional[bool], str) -> Tensor
@@ -1864,7 +1864,7 @@ def poisson_nll_loss(input, target, log_input=True, full=False, size_average=Non
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""The `Kullback-Leibler divergence`_ Loss.
@@ -1920,7 +1920,7 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
     return reduced
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
                   reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], int, Optional[bool], str) -> Tensor
@@ -1969,7 +1969,7 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def binary_cross_entropy(input, target, weight=None, size_average=None,
                          reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@@ -2026,7 +2026,7 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
         input, target, weight, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
                                      reduce=None, reduction='mean', pos_weight=None):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
@@ -2087,14 +2087,14 @@ def _pointwise_loss(lambd, lambd_optimized, input, target, reduction='mean'):
         return lambd_optimized(expanded_input, expanded_target, _Reduction.get_enum(reduction))
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def _smooth_l1_loss(input, target):
     # type: (Tensor, Tensor) -> Tensor
     t = torch.abs(input - target)
     return torch.where(t < 1, 0.5 * t ** 2, t - 0.5)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""Function that uses a squared term if the absolute
@@ -2114,7 +2114,7 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""l1_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@@ -2135,7 +2135,7 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@@ -2156,7 +2156,7 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
                         reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@@ -2174,7 +2174,7 @@ def margin_ranking_loss(input1, input2, target, margin=0, size_average=None,
     return torch.margin_ranking_loss(input1, input2, target, margin, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
                          reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@@ -2189,7 +2189,7 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=None,
     return torch.hinge_embedding_loss(input, target, margin, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@@ -2203,7 +2203,7 @@ def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduct
     return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[bool], Optional[bool], str) -> Tensor
     r"""soft_margin_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
@@ -2217,7 +2217,7 @@ def soft_margin_loss(input, target, size_average=None, reduce=None, reduction='m
     return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
                                 reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@@ -2247,7 +2247,7 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None,
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
                           reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, Tensor, float, Optional[bool], Optional[bool], str) -> Tensor
@@ -2262,7 +2262,7 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None,
     return torch.cosine_embedding_loss(input1, input2, target, margin, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None,
                       reduce=None, reduction='mean'):
     # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor
@@ -2526,7 +2526,7 @@ GRID_SAMPLE_PADDING_MODES = {
 }
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
     # type: (Tensor, Tensor, str, str) -> Tensor
     r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
@@ -2606,7 +2606,7 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
     return torch.grid_sampler(input, grid, mode_enum, padding_mode_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def affine_grid(theta, size):
     # type: (Tensor, List[int]) -> Tensor
     r"""Generates a 2d flow field, given a batch of affine matrices :attr:`theta`
@@ -2624,7 +2624,7 @@ def affine_grid(theta, size):
     return vision.affine_grid_generator(theta, size)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def pad(input, pad, mode='constant', value=0):
     # type: (Tensor, List[int], str, float) -> Tensor
     r"""Pads tensor.
@@ -2716,7 +2716,7 @@ def pad(input, pad, mode='constant', value=0):
 # distance
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def pairwise_distance(x1, x2, p=2., eps=1e-6, keepdim=False):
     # type: (Tensor, Tensor, float, float, bool) -> Tensor
     r"""
@@ -2777,7 +2777,7 @@ Example::
 """)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
                         reduce=None, reduction="mean"):
     # type: (Tensor, Tensor, Tensor, float, float, float, bool, Optional[bool], Optional[bool], str) -> Tensor
@@ -2792,7 +2792,7 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
                                      swap, reduction_enum)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def normalize(input, p=2, dim=1, eps=1e-12, out=None):
     # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
     r"""Performs :math:`L_p` normalization of inputs over specified dimension.
@@ -2826,7 +2826,7 @@ def assert_int_or_pair(arg, arg_name, message):
     assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
     # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor  # noqa
     r"""Extracts sliding local blocks from an batched input tensor.
@@ -2853,7 +2853,7 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
     return ret
 
 
-@torch._jit_internal.weak_script
+@weak_script
 def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
     # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor  # noqa
     r"""Combines an array of sliding local blocks into a large containing