Add (Un)Fold modules to standard library (#14759)
authorDavid Riazati <davidriazati@fb.com>
Tue, 18 Dec 2018 19:43:45 +0000 (11:43 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 20:03:08 +0000 (12:03 -0800)
Summary:
Depends on #14597 for the corresponding aten ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14759

Differential Revision: D13325356

Pulled By: driazati

fbshipit-source-id: 99e39449c1ccfa293de05672c31a11e580bdd11f

test/test_jit.py
torch/csrc/jit/register_special_ops.cpp
torch/jit/__init__.py
torch/nn/functional.py
torch/nn/modules/fold.py

index 89ba4cc..8d88e8c 100644 (file)
@@ -10635,6 +10635,8 @@ EXCLUDE_MODULE_EXPORT_IMPORT = {
     'MaxPool3d',
     'AdaptiveAvgPool2d',
     'AdaptiveAvgPool3d',
+    'Fold',
+    'Unfold',
 }
 
 # NB: JIT script tests for all nn functional interfaces, script mode does
index 9ae08a5..e467823 100644 (file)
@@ -115,6 +115,15 @@ RegisterOperators reg({
           return 0;
         };
       }),
+    Operator(
+      "aten::_assert_int_or_pair(int[] vals, str name, str message) -> Tensor",
+      [](const Node* node) {
+        return [](Stack& stack) {
+          // Everything is a list at the point this is used, so don't do anything
+          drop(stack, 3);
+          return 0;
+        };
+      }),
 
 });
 }
index dc1499f..ab06a34 100644 (file)
@@ -1441,8 +1441,7 @@ def _get_builtin_table():
     _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
     _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
     _builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear"
-    _builtin_table[id(torch.nn.functional.fold)] = "aten::fold"
-    _builtin_table[id(torch.nn.functional.unfold)] = "aten::unfold"
+    _builtin_table[id(torch.nn.functional.assert_int_or_pair)] = "aten::_assert_int_or_pair"
 
     return _builtin_table
 
index b789dde..72c5c27 100644 (file)
@@ -2826,7 +2826,9 @@ def assert_int_or_pair(arg, arg_name, message):
     assert isinstance(arg, int) or len(arg) == 2, message.format(arg_name)
 
 
+@torch._jit_internal.weak_script
 def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
+    # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor  # noqa
     r"""Extracts sliding local blocks from an batched input tensor.
 
     .. warning::
@@ -2843,13 +2845,17 @@ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
         assert_int_or_pair(padding, 'padding', msg)
         assert_int_or_pair(stride, 'stride', msg)
 
-        return torch._C._nn.thnn_im2col(input, _pair(kernel_size),
-                                        _pair(dilation), _pair(padding), _pair(stride))
+        ret = torch._C._nn.thnn_im2col(input, _pair(kernel_size),
+                                       _pair(dilation), _pair(padding), _pair(stride))
     else:
         raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim()))
+        ret = input  # TODO: remove when jit supports exception control flow
+    return ret
 
 
+@torch._jit_internal.weak_script
 def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
+    # type: (Tensor, BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int], BroadcastingList2[int]) -> Tensor  # noqa
     r"""Combines an array of sliding local blocks into a large containing
     tensor.
 
@@ -2867,7 +2873,9 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
         assert_int_or_pair(padding, 'padding', msg)
         assert_int_or_pair(stride, 'stride', msg)
 
-        return torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size),
-                                        _pair(dilation), _pair(padding), _pair(stride))
+        ret = torch._C._nn.thnn_col2im(input, _pair(output_size), _pair(kernel_size),
+                                       _pair(dilation), _pair(padding), _pair(stride))
     else:
         raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim()))
+        ret = input  # TODO: remove when jit supports exception control flow
+    return ret
index d003582..03adaef 100644 (file)
@@ -1,8 +1,10 @@
 # coding=utf-8
 from .module import Module
 from .. import functional as F
+from ..._jit_internal import weak_module, weak_script_method
 
 
+@weak_module
 class Fold(Module):
     r"""Combines an array of sliding local blocks into a large containing
     tensor.
@@ -87,6 +89,8 @@ class Fold(Module):
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
 
     """
+    __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding',
+                     'stride']
 
     def __init__(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
         super(Fold, self).__init__()
@@ -96,6 +100,7 @@ class Fold(Module):
         self.padding = padding
         self.stride = stride
 
+    @weak_script_method
     def forward(self, input):
         return F.fold(input, self.output_size, self.kernel_size, self.dilation,
                       self.padding, self.stride)
@@ -107,6 +112,7 @@ class Fold(Module):
             )
 
 
+@weak_module
 class Unfold(Module):
     r"""Extracts sliding local blocks from a batched input tensor.
 
@@ -201,6 +207,7 @@ class Unfold(Module):
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
 
     """
+    __constants__ = ['kernel_size', 'dilation', 'padding', 'stride']
 
     def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
         super(Unfold, self).__init__()
@@ -209,6 +216,7 @@ class Unfold(Module):
         self.padding = padding
         self.stride = stride
 
+    @weak_script_method
     def forward(self, input):
         return F.unfold(input, self.kernel_size, self.dilation,
                         self.padding, self.stride)