From d76e411d8cbae9cdb61c2226182c80713cc30c41 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 4 Dec 2018 19:52:07 -0800 Subject: [PATCH] support conv transpose in script Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14775 Differential Revision: D13330491 Pulled By: eellison fbshipit-source-id: 432b327d6a33517ff53ea33c9f64700e81432332 --- torch/nn/modules/conv.py | 88 +++++++++++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 31 deletions(-) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index e9f89a1..0ee379e 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -6,7 +6,7 @@ from .. import functional as F from .. import init from .module import Module from .utils import _single, _pair, _triple -from ..._jit_internal import weak_module, weak_script_method +from ..._jit_internal import weak_module, weak_script_method, List @weak_module @@ -448,10 +448,15 @@ class Conv3d(_ConvNd): self.padding, self.dilation, self.groups) +@weak_module class _ConvTransposeMixin(object): + __constants__ = ['stride', 'padding', 'kernel_size', 'dim_size', + 'output_padding', 'groups', 'dilation', 'transposed', 'bias'] + @weak_script_method def forward(self, input, output_size=None): - output_padding = self._output_padding(input, output_size) + # type(Tensor, Optional[List[int]]) -> Tensor + output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) func = self._backend.ConvNd( self.stride, self.padding, self.dilation, self.transposed, output_padding, self.groups) @@ -460,35 +465,48 @@ class _ConvTransposeMixin(object): else: return func(input, self.weight, self.bias) - def _output_padding(self, input, output_size): + @weak_script_method + def _output_padding(self, input, output_size, stride, padding, kernel_size): + # type: (Tensor, Optional[List[int]], List[int], List[int], List[int]) -> List[int] if output_size is None: - return self.output_padding - - output_size = list(output_size) - k = input.dim() - 2 - if len(output_size) == k + 2: - output_size = output_size[2:] - if len(output_size) != k: - raise ValueError( - "output_size must have {} or {} elements (got {})" - .format(k, k + 2, len(output_size))) - - def dim_size(d): - return ((input.size(d + 2) - 1) * self.stride[d] - - 2 * self.padding[d] + self.kernel_size[d]) - - min_sizes = [dim_size(d) for d in range(k)] - max_sizes = [min_sizes[d] + self.stride[d] - 1 for d in range(k)] - for size, min_size, max_size in zip(output_size, min_sizes, max_sizes): - if size < min_size or size > max_size: - raise ValueError(( - "requested an output size of {}, but valid sizes range " - "from {} to {} (for an input of {})").format( - output_size, min_sizes, max_sizes, input.size()[2:])) - - return tuple([output_size[d] - min_sizes[d] for d in range(k)]) + ret = _single(self.output_padding) # converting to list if was not already + else: + output_size = torch.jit._unwrap_optional(output_size) + k = input.dim() - 2 + if len(output_size) == k + 2: + output_size = output_size[2:] + if len(output_size) != k: + raise ValueError( + "output_size must have {} or {} elements (got {})" + .format(k, k + 2, len(output_size))) + + min_sizes = torch.jit.annotate(List[int], []) + max_sizes = torch.jit.annotate(List[int], []) + for d in range(k): + dim_size = ((input.size(d + 2) - 1) * stride[d] - + 2 * padding[d] + kernel_size[d]) + min_sizes.append(dim_size) + max_sizes.append(min_sizes[d] + stride[d] - 1) + + for i in range(len(output_size)): + size = output_size[i] + min_size = min_sizes[i] + max_size = max_sizes[i] + if size < min_size or size > max_size: + raise ValueError(( + "requested an output size of {}, but valid sizes range " + "from {} to {} (for an input of {})").format( + output_size, min_sizes, max_sizes, input.size()[2:])) + + res = torch.jit.annotate(List[int], []) + for d in range(k): + res.append(output_size[d] - min_sizes[d]) + + ret = res + return ret +@weak_module class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): r"""Applies a 1D transposed convolution operator over an input image composed of several input planes. @@ -587,13 +605,16 @@ class ConvTranspose1d(_ConvTransposeMixin, _ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, True, output_padding, groups, bias) + @weak_script_method def forward(self, input, output_size=None): - output_padding = self._output_padding(input, output_size) + # type: (Tensor, Optional[List[int]]) -> Tensor + output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) return F.conv_transpose1d( input, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) +@weak_module class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): r"""Applies a 2D transposed convolution operator over an input image composed of several input planes. @@ -727,13 +748,16 @@ class ConvTranspose2d(_ConvTransposeMixin, _ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, True, output_padding, groups, bias) + @weak_script_method def forward(self, input, output_size=None): - output_padding = self._output_padding(input, output_size) + # type: (Tensor, Optional[List[int]]) -> Tensor + output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) return F.conv_transpose2d( input, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) +@weak_module class ConvTranspose3d(_ConvTransposeMixin, _ConvNd): r"""Applies a 3D transposed convolution operator over an input image composed of several input planes. @@ -862,8 +886,10 @@ class ConvTranspose3d(_ConvTransposeMixin, _ConvNd): in_channels, out_channels, kernel_size, stride, padding, dilation, True, output_padding, groups, bias) + @weak_script_method def forward(self, input, output_size=None): - output_padding = self._output_padding(input, output_size) + # type: (Tensor, Optional[List[int]]) -> Tensor + output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size) return F.conv_transpose3d( input, self.weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) -- 2.7.4