Add Pooling modules to Script (#14527)
authorDavid Riazati <davidriazati@fb.com>
Tue, 4 Dec 2018 07:49:39 +0000 (23:49 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 07:55:04 +0000 (23:55 -0800)
Summary:
Depends on #14584
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14527

Differential Revision: D13270773

Pulled By: driazati

fbshipit-source-id: e4acd43ccbce0f4b62d41c30ce8d5c721171e19a

test/test_jit.py
torch/nn/functional.py
torch/nn/modules/pooling.py

index 7ecbc95..4cb8c04 100644 (file)
@@ -8939,40 +8939,41 @@ a")
         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
     def test_bool_dispatch(self):
-        def kwarg_false(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1, return_indices=False)
-        self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
+        with self.disableModuleHook():  # TODO: Python print broadcasting list
+            def kwarg_false(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1, return_indices=False)
+            self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
 
-        def kwarg_true(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return F.max_pool1d(x, 1, 1, return_indices=True)
-        self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
+            def kwarg_true(x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return F.max_pool1d(x, 1, 1, return_indices=True)
+            self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
 
-        def full_kwarg_false(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
-        self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
+            def full_kwarg_false(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
+            self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
 
-        def full_kwarg_true(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
-        self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
+            def full_kwarg_true(x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
+            self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
 
-        def use_default(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1)
-        self.checkScript(use_default, (torch.randn(3, 3, 3),))
+            def use_default(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1)
+            self.checkScript(use_default, (torch.randn(3, 3, 3),))
 
-        def arg_false(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1, 0, 1, False, False)
-        self.checkScript(arg_false, (torch.randn(3, 3, 3),))
+            def arg_false(x):
+                # type: (Tensor) -> Tensor
+                return F.max_pool1d(x, 1, 1, 0, 1, False, False)
+            self.checkScript(arg_false, (torch.randn(3, 3, 3),))
 
-        def arg_true(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return F.max_pool1d(x, 1, 1, 0, 1, False, True)
-        self.checkScript(arg_true, (torch.randn(3, 3, 3),))
+            def arg_true(x):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return F.max_pool1d(x, 1, 1, 0, 1, False, True)
+            self.checkScript(arg_true, (torch.randn(3, 3, 3),))
 
     def test_infer_size(self):
         from torch._C import _infer_size
@@ -9678,15 +9679,23 @@ EXCLUDE_SCRIPT = {
 }
 
 EXCLUDE_PYTHON_PRINT = {
+    # no support for BroadcastingList in python printer
     'test_nn_max_unpool1d',
     'test_nn_max_unpool2d',
     'test_nn_max_unpool3d',
+    'test_nn_max_pool3d',
+    'test_nn_max_pool2d',
+    'test_nn_max_pool3d'
 }
 
 EXCLUDE_SCRIPT_MODULES = {
     'test_nn_BatchNorm1d_not_tracking_stats',
     'test_nn_BatchNorm2d_not_tracking_stats',
     'test_nn_BatchNorm3d_not_tracking_stats',
+    'test_nn_AdaptiveAvgPool2d_tuple_none',
+    'test_nn_AdaptiveAvgPool3d_tuple_none',
+    'test_nn_AdaptiveMaxPool2d_tuple_none',
+    'test_nn_AdaptiveMaxPool3d_tuple_none',
 }
 
 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
@@ -10206,6 +10215,11 @@ S = 5
 # module cannot be exported /imported currently
 EXCLUDE_MODULE_EXPORT_IMPORT = {
     'EmbeddingBag',
+    'MaxPool1d',
+    'MaxPool2d',
+    'MaxPool3d',
+    'AdaptiveAvgPool2d',
+    'AdaptiveAvgPool3d',
 }
 
 # NB: JIT script tests for all nn functional interfaces, script mode does
index 4313847..d690000 100644 (file)
@@ -372,7 +372,7 @@ fractional_max_pool2d = torch._jit_internal.boolean_dispatch(
 @torch._jit_internal.weak_script
 def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
                             dilation=1, ceil_mode=False, return_indices=False):
-    # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor]  # noqa
+    # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies a 1D max pooling over an input signal composed of several input
     planes.
 
@@ -389,7 +389,7 @@ def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
 @torch._jit_internal.weak_script
 def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
-    # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tensor
+    # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], BroadcastingList1[int], bool, bool) -> Tensor  # noqa
     return max_pool1d_with_indices(
         input, kernel_size, stride, padding, dilation, ceil_mode)[0]
 
@@ -404,7 +404,7 @@ max_pool1d = torch._jit_internal.boolean_dispatch(
 @torch._jit_internal.weak_script
 def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
                             ceil_mode=False, return_indices=False):
-    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor]  # noqa
+    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies a 2D max pooling over an input signal composed of several input
     planes.
 
@@ -420,7 +420,7 @@ def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation
 @torch._jit_internal.weak_script
 def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
-    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tensor
+    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], BroadcastingList2[int], bool, bool) -> Tensor  # noqa
     return max_pool2d_with_indices(
         input, kernel_size, stride, padding, dilation, ceil_mode)[0]
 
@@ -435,7 +435,7 @@ max_pool2d = torch._jit_internal.boolean_dispatch(
 @torch._jit_internal.weak_script
 def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
                             dilation=1, ceil_mode=False, return_indices=False):
-    # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor]  # noqa
+    # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies a 3D max pooling over an input signal composed of several input
     planes.
 
@@ -452,7 +452,7 @@ def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
 @torch._jit_internal.weak_script
 def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
                 ceil_mode=False, return_indices=False):
-    # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tensor
+    # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], BroadcastingList3[int], bool, bool) -> Tensor  # noqa
     return max_pool3d_with_indices(
         input, kernel_size, stride, padding, dilation, ceil_mode)[0]
 
index e4f3ade..3bde44b 100644 (file)
@@ -6,7 +6,10 @@ from .. import functional as F
 from ..._jit_internal import weak_module, weak_script_method
 
 
+@weak_module
 class _MaxPoolNd(Module):
+    __constants__ = ['kernel_size', 'stride', 'padding', 'dilation',
+                     'return_indices', 'ceil_mode']
 
     def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                  return_indices=False, ceil_mode=False):
@@ -23,6 +26,7 @@ class _MaxPoolNd(Module):
             ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
 
 
+@weak_module
 class MaxPool1d(_MaxPoolNd):
     r"""Applies a 1D max pooling over an input signal composed of several input
     planes.
@@ -66,6 +70,7 @@ class MaxPool1d(_MaxPoolNd):
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.max_pool1d(input, self.kernel_size, self.stride,
                             self.padding, self.dilation, self.ceil_mode,
@@ -76,6 +81,7 @@ class MaxPool1d(_MaxPoolNd):
             ', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)
 
 
+@weak_module
 class MaxPool2d(_MaxPoolNd):
     r"""Applies a 2D max pooling over an input signal composed of several input
     planes.
@@ -135,12 +141,14 @@ class MaxPool2d(_MaxPoolNd):
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.max_pool2d(input, self.kernel_size, self.stride,
                             self.padding, self.dilation, self.ceil_mode,
                             self.return_indices)
 
 
+@weak_module
 class MaxPool3d(_MaxPoolNd):
     r"""Applies a 3D max pooling over an input signal composed of several input
     planes. This is not a test
@@ -204,12 +212,14 @@ class MaxPool3d(_MaxPoolNd):
         https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
     """  # noqa: E501
 
+    @weak_script_method
     def forward(self, input):
         return F.max_pool3d(input, self.kernel_size, self.stride,
                             self.padding, self.dilation, self.ceil_mode,
                             self.return_indices)
 
 
+@weak_module
 class _MaxUnpoolNd(Module):
 
     def extra_repr(self):
@@ -218,6 +228,7 @@ class _MaxUnpoolNd(Module):
         )
 
 
+@weak_module
 class MaxUnpool1d(_MaxUnpoolNd):
     r"""Computes a partial inverse of :class:`MaxPool1d`.
 
@@ -283,6 +294,7 @@ class MaxUnpool1d(_MaxUnpoolNd):
                               self.padding, output_size)
 
 
+@weak_module
 class MaxUnpool2d(_MaxUnpoolNd):
     r"""Computes a partial inverse of :class:`MaxPool2d`.
 
@@ -356,6 +368,7 @@ class MaxUnpool2d(_MaxUnpoolNd):
                               self.padding, output_size)
 
 
+@weak_module
 class MaxUnpool3d(_MaxUnpoolNd):
     r"""Computes a partial inverse of :class:`MaxPool3d`.
 
@@ -418,6 +431,7 @@ class MaxUnpool3d(_MaxUnpoolNd):
                               self.padding, output_size)
 
 
+@weak_module
 class _AvgPoolNd(Module):
 
     def extra_repr(self):
@@ -426,6 +440,7 @@ class _AvgPoolNd(Module):
         )
 
 
+@weak_module
 class AvgPool1d(_AvgPoolNd):
     r"""Applies a 1D average pooling over an input signal composed of several
     input planes.
@@ -467,6 +482,8 @@ class AvgPool1d(_AvgPoolNd):
         >>> m(torch.tensor([[[1.,2,3,4,5,6,7]]]))
         tensor([[[ 2.,  4.,  6.]]])
     """
+    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
+                     'count_include_pad']
 
     def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                  count_include_pad=True):
@@ -477,12 +494,14 @@ class AvgPool1d(_AvgPoolNd):
         self.ceil_mode = ceil_mode
         self.count_include_pad = count_include_pad
 
+    @weak_script_method
     def forward(self, input):
         return F.avg_pool1d(
             input, self.kernel_size, self.stride, self.padding, self.ceil_mode,
             self.count_include_pad)
 
 
+@weak_module
 class AvgPool2d(_AvgPoolNd):
     r"""Applies a 2D average pooling over an input signal composed of several input
     planes.
@@ -533,6 +552,8 @@ class AvgPool2d(_AvgPoolNd):
         >>> input = torch.randn(20, 16, 50, 32)
         >>> output = m(input)
     """
+    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
+                     'count_include_pad']
 
     def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                  count_include_pad=True):
@@ -543,11 +564,13 @@ class AvgPool2d(_AvgPoolNd):
         self.ceil_mode = ceil_mode
         self.count_include_pad = count_include_pad
 
+    @weak_script_method
     def forward(self, input):
         return F.avg_pool2d(input, self.kernel_size, self.stride,
                             self.padding, self.ceil_mode, self.count_include_pad)
 
 
+@weak_module
 class AvgPool3d(_AvgPoolNd):
     r"""Applies a 3D average pooling over an input signal composed of several input
     planes.
@@ -605,6 +628,8 @@ class AvgPool3d(_AvgPoolNd):
         >>> input = torch.randn(20, 16, 50,44, 31)
         >>> output = m(input)
     """
+    __constants__ = ['kernel_size', 'stride', 'padding', 'ceil_mode',
+                     'count_include_pad']
 
     def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                  count_include_pad=True):
@@ -615,6 +640,7 @@ class AvgPool3d(_AvgPoolNd):
         self.ceil_mode = ceil_mode
         self.count_include_pad = count_include_pad
 
+    @weak_script_method
     def forward(self, input):
         return F.avg_pool3d(input, self.kernel_size, self.stride,
                             self.padding, self.ceil_mode, self.count_include_pad)
@@ -739,6 +765,7 @@ class LPPool1d(_LPPoolNd):
     """
 
     @weak_script_method
+    @weak_script_method
     def forward(self, input):
         return F.lp_pool1d(input, float(self.norm_type), self.kernel_size,
                            self.stride, self.ceil_mode)
@@ -800,7 +827,9 @@ class LPPool2d(_LPPoolNd):
                            self.stride, self.ceil_mode)
 
 
+@weak_module
 class _AdaptiveMaxPoolNd(Module):
+    __constants__ = ['output_size', 'return_indices']
 
     def __init__(self, output_size, return_indices=False):
         super(_AdaptiveMaxPoolNd, self).__init__()
@@ -814,6 +843,7 @@ class _AdaptiveMaxPoolNd(Module):
 #   output shapes are, and how the operation computes output.
 
 
+@weak_module
 class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
     r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.
 
@@ -833,10 +863,12 @@ class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
 
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.adaptive_max_pool1d(input, self.output_size, self.return_indices)
 
 
+@weak_module
 class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
     r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.
 
@@ -867,10 +899,12 @@ class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
 
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.adaptive_max_pool2d(input, self.output_size, self.return_indices)
 
 
+@weak_module
 class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
     r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.
 
@@ -902,11 +936,14 @@ class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
 
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.adaptive_max_pool3d(input, self.output_size, self.return_indices)
 
 
+@weak_module
 class _AdaptiveAvgPoolNd(Module):
+    __constants__ = ['output_size']
 
     def __init__(self, output_size):
         super(_AdaptiveAvgPoolNd, self).__init__()
@@ -916,6 +953,7 @@ class _AdaptiveAvgPoolNd(Module):
         return 'output_size={}'.format(self.output_size)
 
 
+@weak_module
 class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
     r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.
 
@@ -933,10 +971,12 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
 
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.adaptive_avg_pool1d(input, self.output_size)
 
 
+@weak_module
 class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
     r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.
 
@@ -965,10 +1005,12 @@ class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
 
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.adaptive_avg_pool2d(input, self.output_size)
 
 
+@weak_module
 class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
     r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.
 
@@ -997,5 +1039,6 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
 
     """
 
+    @weak_script_method
     def forward(self, input):
         return F.adaptive_avg_pool3d(input, self.output_size)