support conv transpose in script
author Elias Ellison <eellison@fb.com>
Wed, 5 Dec 2018 03:52:07 +0000 (19:52 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 5 Dec 2018 03:54:09 +0000 (19:54 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14775

Differential Revision: D13330491

Pulled By: eellison

fbshipit-source-id: 432b327d6a33517ff53ea33c9f64700e81432332

torch/nn/modules/conv.py

index e9f89a1..0ee379e 100644 (file)
@@ -6,7 +6,7 @@ from .. import functional as F
 from .. import init
 from .module import Module
 from .utils import _single, _pair, _triple
-from ..._jit_internal import weak_module, weak_script_method
+from ..._jit_internal import weak_module, weak_script_method, List
 
 
 @weak_module
@@ -448,10 +448,15 @@ class Conv3d(_ConvNd):
                         self.padding, self.dilation, self.groups)
 
 
+@weak_module
 class _ConvTransposeMixin(object):
+    __constants__ = ['stride', 'padding', 'kernel_size', 'dim_size',
+                     'output_padding', 'groups', 'dilation', 'transposed', 'bias']
 
+    @weak_script_method
     def forward(self, input, output_size=None):
-        output_padding = self._output_padding(input, output_size)
+        # type: (Tensor, Optional[List[int]]) -> Tensor
+        output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
         func = self._backend.ConvNd(
             self.stride, self.padding, self.dilation, self.transposed,
             output_padding, self.groups)
@@ -460,35 +465,48 @@ class _ConvTransposeMixin(object):
         else:
             return func(input, self.weight, self.bias)
 
-    def _output_padding(self, input, output_size):
+    @weak_script_method
+    def _output_padding(self, input, output_size, stride, padding, kernel_size):
+        # type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int]
         if output_size is None:
-            return self.output_padding
-
-        output_size = list(output_size)
-        k = input.dim() - 2
-        if len(output_size) == k + 2:
-            output_size = output_size[2:]
-        if len(output_size) != k:
-            raise ValueError(
-                "output_size must have {} or {} elements (got {})"
-                .format(k, k + 2, len(output_size)))
-
-        def dim_size(d):
-            return ((input.size(d + 2) - 1) * self.stride[d] -
-                    2 * self.padding[d] + self.kernel_size[d])
-
-        min_sizes = [dim_size(d) for d in range(k)]
-        max_sizes = [min_sizes[d] + self.stride[d] - 1 for d in range(k)]
-        for size, min_size, max_size in zip(output_size, min_sizes, max_sizes):
-            if size < min_size or size > max_size:
-                raise ValueError((
-                    "requested an output size of {}, but valid sizes range "
-                    "from {} to {} (for an input of {})").format(
-                        output_size, min_sizes, max_sizes, input.size()[2:]))
-
-        return tuple([output_size[d] - min_sizes[d] for d in range(k)])
+            ret = _single(self.output_padding)  # converting to list if was not already
+        else:
+            output_size = torch.jit._unwrap_optional(output_size)
+            k = input.dim() - 2
+            if len(output_size) == k + 2:
+                output_size = output_size[2:]
+            if len(output_size) != k:
+                raise ValueError(
+                    "output_size must have {} or {} elements (got {})"
+                    .format(k, k + 2, len(output_size)))
+
+            min_sizes = torch.jit.annotate(List[int], [])
+            max_sizes = torch.jit.annotate(List[int], [])
+            for d in range(k):
+                dim_size = ((input.size(d + 2) - 1) * stride[d] -
+                            2 * padding[d] + kernel_size[d])
+                min_sizes.append(dim_size)
+                max_sizes.append(min_sizes[d] + stride[d] - 1)
+
+            for i in range(len(output_size)):
+                size = output_size[i]
+                min_size = min_sizes[i]
+                max_size = max_sizes[i]
+                if size < min_size or size > max_size:
+                    raise ValueError((
+                        "requested an output size of {}, but valid sizes range "
+                        "from {} to {} (for an input of {})").format(
+                            output_size, min_sizes, max_sizes, input.size()[2:]))
+
+            res = torch.jit.annotate(List[int], [])
+            for d in range(k):
+                res.append(output_size[d] - min_sizes[d])
+
+            ret = res
+        return ret
 
 
+@weak_module
 class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
     r"""Applies a 1D transposed convolution operator over an input image
     composed of several input planes.
@@ -587,13 +605,16 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd):
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             True, output_padding, groups, bias)
 
+    @weak_script_method
     def forward(self, input, output_size=None):
-        output_padding = self._output_padding(input, output_size)
+        # type: (Tensor, Optional[List[int]]) -> Tensor
+        output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
         return F.conv_transpose1d(
             input, self.weight, self.bias, self.stride, self.padding,
             output_padding, self.groups, self.dilation)
 
 
+@weak_module
 class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
     r"""Applies a 2D transposed convolution operator over an input image
     composed of several input planes.
@@ -727,13 +748,16 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd):
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             True, output_padding, groups, bias)
 
+    @weak_script_method
     def forward(self, input, output_size=None):
-        output_padding = self._output_padding(input, output_size)
+        # type: (Tensor, Optional[List[int]]) -> Tensor
+        output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
         return F.conv_transpose2d(
             input, self.weight, self.bias, self.stride, self.padding,
             output_padding, self.groups, self.dilation)
 
 
+@weak_module
 class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
     r"""Applies a 3D transposed convolution operator over an input image composed of several input
     planes.
@@ -862,8 +886,10 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd):
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             True, output_padding, groups, bias)
 
+    @weak_script_method
     def forward(self, input, output_size=None):
-        output_padding = self._output_padding(input, output_size)
+        # type: (Tensor, Optional[List[int]]) -> Tensor
+        output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
         return F.conv_transpose3d(
             input, self.weight, self.bias, self.stride, self.padding,
             output_padding, self.groups, self.dilation)