From 666d383a0066ffb5bee2d2062b7fa1c94c8f1a7e Mon Sep 17 00:00:00 2001
From: David Riazati
Date: Thu, 29 Nov 2018 15:13:45 -0800
Subject: [PATCH] Add broadcast list default arg support (#14361)

Summary:
To allow the `max_unpool` functions to be converted to weak script, this PR
adds support for using a scalar of type `T` as the default argument of a
parameter typed `BroadcastingListN[T]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14361

Differential Revision: D13192231

Pulled By: driazati

fbshipit-source-id: a25b75a0e88ba3dfa22d6a83775e9778d735e249
---
 test/test_jit.py               | 22 +++++++++----
 test/test_nn.py                |  7 ++--
 torch/csrc/jit/script/init.cpp | 10 +++++-
 torch/nn/functional.py         | 72 ++++++++++++++++++++++++++----------------
 4 files changed, 73 insertions(+), 38 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 63a22ab..f247b67 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9406,13 +9406,11 @@ EXCLUDE_SCRIPT = {
 
     # argument has custom behavior
     'test_nn_fractional_max_pool2d',
-    'test_nn_max_unpool3d',
     'test_nn_batch_norm',
 
     # aten op has additional cudnn argument
     'test_nn_group_norm',
     'test_nn_unfold',
-    'test_nn_max_unpool2d',
 
     # flakey test - TODO fix
     'test_nn_ctc_loss',
@@ -9423,7 +9421,12 @@ EXCLUDE_SCRIPT = {
     'test_nn_cross_entropy',
     'test_nn_interpolate',
     'test_nn_fold',
+}
+
+EXCLUDE_PYTHON_PRINT = {
     'test_nn_max_unpool1d',
+    'test_nn_max_unpool2d',
+    'test_nn_max_unpool3d',
 }
 
 EXCLUDE_SCRIPT_MODULES = {
@@ -10204,10 +10207,17 @@ def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=()
 
         if test_name not in EXCLUDE_SCRIPT:
             disable_ad_subgraph_inlining = test_name in DISABLE_AUTODIFF_SUBGRAPH_INLINING
-            check_against_reference(self,
-                                    create_script_fn(self, name, 'nn_functional', output_process_fn,
-                                                     disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining),
-                                    fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+
+            def run_test():
+                script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
+                                             disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
+                check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+
+            if test_name in EXCLUDE_PYTHON_PRINT:
+                with self.disableModuleHook():
+                    run_test()
+            else:
+                run_test()
 
     post_add_test(test_name, skipTestIf, do_test)
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 465c565..0746d68 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3952,10 +3952,9 @@ class TestNN(NNTestCase):
             for w in range(3, 10):
                 if 4 <= h <= 6 and 4 <= w <= 6:
                     size = (h, w)
-                    if h == 5:
-                        size = torch.LongStorage(size)
-                    elif h == 6:
-                        size = torch.LongStorage((1, 1) + size)
+                    if h == 6:
+                        size = (1, 1) + size
+
                     mu(output_small, indices_small, output_size=size)
                 else:
                     self.assertRaises(ValueError, lambda: mu(output_small, indices_small, (h, w)))
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index b25edab..8306103 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -459,7 +459,15 @@ FunctionSchema getSchemaWithNameAndDefaults(
     auto it = default_args.find(arg.name());
     if (it != default_args.end()) {
       try {
-        IValue value = toIValue(it->second, arg.type());
+        IValue value;
+        auto n = arg.N();
+        auto list_type = arg.type()->cast<ListType>();
+        if (n && *n > 0 && list_type) {
+          // BroadcastingList, allow default values T for arg types List[T]
+          value = toIValue(it->second, list_type->getElementType());
+        } else {
+          value = toIValue(it->second, arg.type());
+        }
         new_args.emplace_back(
             Argument(arg.name(), arg.type(), arg.N(), value, arg.kwarg_only()));
       } catch (py::cast_error& e) {
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 5a8ba25..56feb47 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -446,75 +446,93 @@ max_pool3d = torch._jit_internal.boolean_dispatch(
     if_false=_max_pool3d)
 
 
+@torch._jit_internal.weak_script
 def _unpool_output_size(input, kernel_size, stride, padding, output_size):
+    # type: (Tensor, List[int], List[int], List[int], Optional[List[int]]) -> List[int]
     input_size = input.size()
-    default_size = []
+    default_size = torch.jit.annotate(List[int], [])
     for d in range(len(kernel_size)):
         default_size.append((input_size[d + 2] - 1) * stride[d] + kernel_size[d] - 2 * padding[d])
     if output_size is None:
-        return default_size
-
-    output_size = list(output_size)
-    if len(output_size) == len(kernel_size) + 2:
-        output_size = output_size[2:]
-    if len(output_size) != len(kernel_size):
-        raise ValueError("output_size should be a sequence containing "
-                         "{} or {} elements, but it has a length of '{}'"
-                         .format(len(kernel_size), len(kernel_size) + 2,
-                                 len(output_size)))
-    for d in range(len(kernel_size)):
-        min_size = default_size[d] - stride[d]
-        max_size = default_size[d] + stride[d]
-        if not (min_size < output_size[d] < max_size):
-            raise ValueError(
-                'invalid output_size "{}" (dim {} must be between {} and {})'
-                .format(output_size, d, min_size, max_size))
-
-    return output_size
+        ret = default_size
+    else:
+        output_size = torch.jit._unwrap_optional(output_size)
+        if len(output_size) == len(kernel_size) + 2:
+            output_size = output_size[2:]
+        if len(output_size) != len(kernel_size):
+            raise ValueError("output_size should be a sequence containing "
+                             "{} or {} elements, but it has a length of '{}'"
+                             .format(len(kernel_size), len(kernel_size) + 2,
+                                     len(output_size)))
+        for d in range(len(kernel_size)):
+            min_size = default_size[d] - stride[d]
+            max_size = default_size[d] + stride[d]
+            if not (min_size < output_size[d] < max_size):
+                raise ValueError(
+                    'invalid output_size "{}" (dim {} must be between {} and {})'
+                    .format(output_size, d, min_size, max_size))
+
+        ret = output_size
+    return ret
 
 
 @torch._jit_internal.weak_script
 def max_unpool1d(input, indices, kernel_size, stride=None, padding=0,
                  output_size=None):
+    # type: (Tensor, Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], BroadcastingList1[int], Optional[BroadcastingList1[int]]) -> Tensor  # noqa
     r"""Computes a partial inverse of :class:`MaxPool1d`.
 
     See :class:`~torch.nn.MaxUnpool1d` for details.
     """
     kernel_size = _single(kernel_size)
-    stride = _single(stride or kernel_size)
+    if stride is not None:
+        _stride = _single(torch.jit._unwrap_optional(stride))
+    else:
+        _stride = kernel_size
     padding = _single(padding)
-    output_size = _unpool_output_size(input, kernel_size, stride, padding,
+    output_size = _unpool_output_size(input, kernel_size, _stride, padding,
                                       output_size)
     return torch._C._nn.max_unpool2d(input.unsqueeze(3), indices.unsqueeze(3),
                                      output_size + [1]).squeeze(3)
 
 
+@torch._jit_internal.weak_script
 def max_unpool2d(input, indices, kernel_size, stride=None, padding=0,
                  output_size=None):
+    # type: (Tensor, Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], BroadcastingList2[int], Optional[BroadcastingList2[int]]) -> Tensor  # noqa
     r"""Computes a partial inverse of :class:`MaxPool2d`.
 
     See :class:`~torch.nn.MaxUnpool2d` for details.
""" kernel_size = _pair(kernel_size) - stride = _pair(stride or kernel_size) + if stride is not None: + _stride = _pair(torch.jit._unwrap_optional(stride)) + else: + _stride = kernel_size padding = _pair(padding) - output_size = _unpool_output_size(input, kernel_size, stride, padding, + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) return torch._C._nn.max_unpool2d(input, indices, output_size) +@torch._jit_internal.weak_script def max_unpool3d(input, indices, kernel_size, stride=None, padding=0, output_size=None): + # type: (Tensor, Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], BroadcastingList3[int], Optional[BroadcastingList3[int]]) -> Tensor # noqa r"""Computes a partial inverse of :class:`MaxPool3d`. See :class:`~torch.nn.MaxUnpool3d` for details. """ kernel_size = _triple(kernel_size) - stride = _triple(stride or kernel_size) + if stride is not None: + _stride = _triple(torch.jit._unwrap_optional(stride)) + else: + _stride = kernel_size padding = _triple(padding) - output_size = _unpool_output_size(input, kernel_size, stride, padding, + output_size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) - return torch._C._nn.max_unpool3d(input, indices, output_size, stride, padding) + return torch._C._nn.max_unpool3d( + input, indices, output_size, _stride, padding) @torch._jit_internal.weak_script -- 2.7.4