From 1b80644b4d6280a37ebbea67dc02a2b2540ac5e0 Mon Sep 17 00:00:00 2001
From: David Riazati
Date: Tue, 27 Nov 2018 13:12:14 -0800
Subject: [PATCH] Revert D13192228: [pytorch][PR] [jit] Add boolean dispatch
 for function overloading

Differential Revision: D13192228

Original commit changeset: fce33c400c1f

fbshipit-source-id: 75c9991dc7097f9513c6c89d16eff2de6e287c3b
---
 test/test_jit.py                 |  44 ++---------
 torch/_jit_internal.py           |  45 -----------
 torch/csrc/jit/script/compiler.h |   5 --
 torch/csrc/jit/script/init.cpp   |  62 ---------------
 torch/jit/__init__.py            |   8 +-
 torch/nn/functional.py           | 158 +++++++-------------------------
 6 files changed, 34 insertions(+), 288 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 0089330..d5263de 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8714,42 +8714,6 @@ a")
             return foo, a, bar
         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
-    def test_bool_dispatch(self):
-        def kwarg_false(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1, return_indices=False)
-        self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
-
-        def kwarg_true(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return F.max_pool1d(x, 1, 1, return_indices=True)
-        self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
-
-        def full_kwarg_false(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
-        self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
-
-        def full_kwarg_true(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
-        self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
-
-        def use_default(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1)
-        self.checkScript(use_default, (torch.randn(3, 3, 3),))
-
-        def arg_false(x):
-            # type: (Tensor) -> Tensor
-            return F.max_pool1d(x, 1, 1, 0, 1, False, False)
-        self.checkScript(arg_false, (torch.randn(3, 3, 3),))
-
-        def arg_true(x):
-            # type: (Tensor) -> Tuple[Tensor, Tensor]
-            return F.max_pool1d(x, 1, 1, 0, 1, False, True)
-        self.checkScript(arg_true, (torch.randn(3, 3, 3),))
-
 
 class MnistNet(nn.Module):
     def __init__(self):
@@ -9353,6 +9317,13 @@ EXCLUDE_SCRIPT = {
     'test_norm_fro',
     'test_norm_fro_default',
     'test_norm_nuc',
+    # skipped nn functional tests
+    # ops involve sampling which cannot be tested
+
+    'test_nn_adaptive_max_pool1d',
+    'test_nn_adaptive_max_pool2d',
+    'test_nn_adaptive_max_pool3d',
+
     # argument has custom behavior
     'test_nn_fractional_max_pool2d',
@@ -9940,7 +9911,6 @@ nn_functional_tests = [
     ('avg_pool3d', (S, S, S, S, S), (3,)),
     ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3], None)),
     ('max_pool1d', (S, S, S), (2, 1)),
-    ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
     ('max_pool2d', (S, S, S, S), (2, 1)),
     ('max_pool3d', (S, S, S, S, S), (2, 1)),
     ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index f1b736e..07b301c 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -23,10 +23,6 @@ _weak_modules = weakref.WeakKeyDictionary()
 # Types that have been declared as weak modules
 _weak_types = weakref.WeakKeyDictionary()
 
-# Wrapper functions that can call either of 2 functions depending on a boolean
-# argument
-_boolean_dispatched = weakref.WeakKeyDictionary()
-
 COMPILATION_PENDING = object()
 COMPILED = object()
 
@@ -108,44 +104,3 @@ def weak_script_method(fn):
         "original_method": fn
     }
     return fn
-
-
-def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
-    """
-    Dispatches to either of 2 weak script functions based on a boolean argument.
-    In Torch Script, the boolean argument must be constant so that the correct
-    function to use can be determined at compile time.
-    """
-    if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
-        raise RuntimeError("both functions must be weak script")
-
-    def fn(*args, **kwargs):
-        dispatch_flag = False
-        if arg_name in kwargs:
-            dispatch_flag = kwargs[arg_name]
-        elif arg_index < len(args):
-            dispatch_flag = args[arg_index]
-
-        if dispatch_flag:
-            return if_true(*args, **kwargs)
-        else:
-            return if_false(*args, **kwargs)
-
-    if if_true.__doc__ is None and if_false.__doc__ is not None:
-        doc = if_false.__doc__
-        if_true.__doc__ = doc
-    elif if_false.__doc__ is None and if_true.__doc__ is not None:
-        doc = if_true.__doc__
-        if_false.__doc__ = doc
-    else:
-        raise RuntimeError("only one function can have a docstring")
-    fn.__doc__ = doc
-
-    _boolean_dispatched[fn] = {
-        "if_true": if_true,
-        "if_false": if_false,
-        "index": arg_index,
-        "default": default,
-        "arg_name": arg_name
-    }
-    return fn
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index 9bd1016..089cc24 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -232,11 +232,6 @@ TORCH_API Value* emitBuiltinCall(
 // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
 // otherwise it will return nullptr if the builtin is not found.
 bool required);
-
-c10::optional<size_t> findInputWithName(
-    const std::string& name,
-    at::ArrayRef<NamedValue> kwargs);
-
 } // namespace script
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index f44ae7a..459d5c7 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -261,62 +261,6 @@ struct ModuleValue : public SugaredValue {
   std::shared_ptr<Module> module;
 };
 
-struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
-  BooleanDispatchValue(py::dict dispatched_fn)
-      : dispatched_fn_(std::move(dispatched_fn)) {}
-
-  std::string kind() const override {
-    return "boolean dispatch";
-  }
-
-  std::vector<NamedValue> removeIndex(
-      at::ArrayRef<NamedValue> arr,
-      size_t index) {
-    auto sliced = arr.vec();
-    sliced.erase(sliced.begin() + index);
-    return sliced;
-  }
-
-  std::shared_ptr<SugaredValue> call(
-      SourceRange loc,
-      Method& caller,
-      at::ArrayRef<NamedValue> inputs,
-      at::ArrayRef<NamedValue> attributes,
-      size_t n_binders) override {
-    c10::optional<bool> result;
-    Graph& graph = *(caller.graph());
-
-    auto index = py::cast<size_t>(dispatched_fn_["index"]);
-    auto arg_name = py::str(dispatched_fn_["arg_name"]);
-
-    if (index < inputs.size()) {
-      // Dispatch flag is in arg list
-      result = constant_as<bool>(inputs.at(index).value(graph));
-    } else if (auto i = findInputWithName(arg_name, attributes)) {
-      // Dispatch flag is in kwargs
-      result = constant_as<bool>(attributes[*i].value(graph));
-    } else {
-      // Didn't find dispatch flag, so use default value
-      result = py::cast<bool>(dispatched_fn_["default"]);
-    }
-
-    if (!result) {
-      throw ErrorReport(loc) << "value for boolean dispatch was not constant";
-    }
-
-    std::shared_ptr<SugaredValue> value;
-    if (*result) {
-      value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
-    } else {
-      value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
-    }
-    return value->call(loc, caller, inputs, attributes, n_binders);
-  }
-
- private:
-  py::dict dispatched_fn_;
-};
-
 std::shared_ptr<SugaredValue> toSugaredValue(
     py::object obj,
     Method& m,
@@ -393,12 +337,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
       return std::make_shared<ModuleValue>(mod);
     }
   }
-
-  py::object dispatched_fn =
-      py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
-  if (!dispatched_fn.is_none()) {
-    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
-  }
   return std::make_shared<PythonValue>(obj);
 }
 
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index ecad9a1..5d9b227 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -7,7 +7,7 @@ import torch.jit.annotations
 from torch._six import raise_from, with_metaclass, get_function_from_type
 from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
     _weak_script_methods, _weak_modules, _weak_types, COMPILED, \
-    COMPILATION_PENDING, _boolean_dispatched
+    COMPILATION_PENDING
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
 import torch.testing
@@ -637,10 +637,6 @@ class CompilationUnit(object):
         return self.module._get_method(attr)
 
 
-def _try_get_dispatched_fn(fn):
-    return _boolean_dispatched.get(fn)
-
-
 def _try_compile_weak_script(fn):
     entry = _compiled_weak_fns.get(fn)
     if entry is None:
@@ -1354,7 +1350,7 @@ def _should_skip(mod, name):
     func = getattr(torch.nn.functional, name)
     if func is None:
         return False
-    return func in _compiled_weak_fns or func in _boolean_dispatched
+    return func in _compiled_weak_fns
 
 
 def _unwrap_optional(x):
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 9b94bc9..91352fc 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4,7 +4,6 @@ from __future__ import division
 import warnings
 import math
 import types
-from typing import List
 
 import torch
 from torch._C import _infer_size, _add_docstr
@@ -305,7 +304,6 @@ Args:
 def fractional_max_pool2d(input, kernel_size, output_size=None,
                           output_ratio=None, return_indices=False,
                           _random_samples=None):
-    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList1[int]], float, bool, Tensor) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies 2D fractional max pooling over an input signal composed of several input planes.
Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham @@ -339,111 +337,47 @@ def fractional_max_pool2d(input, kernel_size, output_size=None, raise ValueError("fractional_max_pool2d requires specifying either " "an output_size, or a output_ratio") if output_size is None: - _output_ratio = _pair(output_ratio) - _output_size = (int(input.size(2) * _output_ratio[0]), - int(input.size(3) * _output_ratio[1])) - else: - _output_size = torch.jit._unwrap_optional(output_size) + output_ratio = _pair(output_ratio) + output_size = (int(input.size(2) * output_ratio[0]), + int(input.size(3) * output_ratio[1])) if _random_samples is None: _random_samples = input.new(input.size(0), input.size(1), 2).uniform_() - ret = torch._C._nn.fractional_max_pool2d(input, kernel_size, _output_size, _random_samples) + ret = torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples) return ret if return_indices else ret[0] -@torch._jit_internal.weak_script -def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): - # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor] # noqa +def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): r"""Applies a 1D max pooling over an input signal composed of several input planes. See :class:`~torch.nn.MaxPool1d` for details. """ - if stride is None: - _stride = torch.jit.annotate(List[int], []) - else: - _stride = torch.jit._unwrap_optional(stride) - return torch.max_pool1d_with_indices( - input, kernel_size, _stride, padding, dilation, ceil_mode) - - -@torch._jit_internal.weak_script -def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): - # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tensor - return max_pool1d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode)[0] - -max_pool1d = torch._jit_internal.boolean_dispatch( - arg_name='return_indices', - arg_index=6, - default=False, - if_true=max_pool1d_with_indices, - if_false=_max_pool1d) + ret = torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + return ret if return_indices else ret[0] -@torch._jit_internal.weak_script -def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): - # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor] # noqa +def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): r"""Applies a 2D max pooling over an input signal composed of several input planes. See :class:`~torch.nn.MaxPool2d` for details. 
""" - if stride is None: - _stride = torch.jit.annotate(List[int], []) - else: - _stride = torch.jit._unwrap_optional(stride) - return torch._C._nn.max_pool2d_with_indices(input, kernel_size, _stride, padding, dilation, ceil_mode) - - -@torch._jit_internal.weak_script -def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): - # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tensor - return max_pool2d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode)[0] - -max_pool2d = torch._jit_internal.boolean_dispatch( - arg_name='return_indices', - arg_index=6, - default=False, - if_true=max_pool2d_with_indices, - if_false=_max_pool2d) + ret = torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + return ret if return_indices else ret[0] -@torch._jit_internal.weak_script -def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0, - dilation=1, ceil_mode=False, return_indices=False): - # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor] # noqa +def max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, + ceil_mode=False, return_indices=False): r"""Applies a 3D max pooling over an input signal composed of several input planes. See :class:`~torch.nn.MaxPool3d` for details. """ - if stride is None: - _stride = torch.jit.annotate(List[int], []) - else: - _stride = torch.jit._unwrap_optional(stride) - return torch._C._nn.max_pool3d_with_indices( - input, kernel_size, _stride, padding, dilation, ceil_mode) - - -@torch._jit_internal.weak_script -def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, - ceil_mode=False, return_indices=False): - # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tensor - return max_pool3d_with_indices( - input, kernel_size, stride, padding, dilation, ceil_mode)[0] - -max_pool3d = torch._jit_internal.boolean_dispatch( - arg_name='return_indices', - arg_index=6, - default=False, - if_true=max_pool3d_with_indices, - if_false=_max_pool3d) + ret = torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + return ret if return_indices else ret[0] def _unpool_output_size(input, kernel_size, stride, padding, output_size): @@ -554,9 +488,7 @@ def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False): return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type) -@torch._jit_internal.weak_script -def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): - # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor] +def adaptive_max_pool1d(input, output_size, return_indices=False): r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes. @@ -566,25 +498,11 @@ def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False): output_size: the target output size (single integer) return_indices: whether to return pooling indices. 
Default: ``False`` """ - return torch.adaptive_max_pool1d(input, output_size) - - -@torch._jit_internal.weak_script -def _adaptive_max_pool1d(input, output_size, return_indices=False): - # type: (Tensor, BroadcastingList1[int], bool) -> Tensor - return adaptive_max_pool1d_with_indices(input, output_size)[0] - -adaptive_max_pool1d = torch._jit_internal.boolean_dispatch( - arg_name='return_indices', - arg_index=2, - default=False, - if_true=adaptive_max_pool1d_with_indices, - if_false=_adaptive_max_pool1d) + ret = torch.adaptive_max_pool1d(input, output_size) + return ret if return_indices else ret[0] -@torch._jit_internal.weak_script -def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): - # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor] +def adaptive_max_pool2d(input, output_size, return_indices=False): r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes. @@ -596,25 +514,11 @@ def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False): return_indices: whether to return pooling indices. Default: ``False`` """ output_size = _list_with_default(output_size, input.size()) - return torch._C._nn.adaptive_max_pool2d(input, output_size) - - -@torch._jit_internal.weak_script -def _adaptive_max_pool2d(input, output_size, return_indices=False): - # type: (Tensor, BroadcastingList1[int], bool) -> Tensor - return adaptive_max_pool2d_with_indices(input, output_size)[0] - -adaptive_max_pool2d = torch._jit_internal.boolean_dispatch( - arg_name='return_indices', - arg_index=2, - default=False, - if_true=adaptive_max_pool2d_with_indices, - if_false=_adaptive_max_pool2d) + ret = torch._C._nn.adaptive_max_pool2d(input, output_size) + return ret if return_indices else ret[0] -@torch._jit_internal.weak_script -def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): - # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor] +def adaptive_max_pool3d(input, output_size, return_indices=False): r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes. @@ -626,20 +530,8 @@ def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False): return_indices: whether to return pooling indices. Default: ``False`` """ output_size = _list_with_default(output_size, input.size()) - return torch._C._nn.adaptive_max_pool3d(input, output_size) - - -@torch._jit_internal.weak_script -def _adaptive_max_pool3d(input, output_size, return_indices=False): - # type: (Tensor, BroadcastingList1[int], bool) -> Tensor - return adaptive_max_pool3d_with_indices(input, output_size)[0] - -adaptive_max_pool3d = torch._jit_internal.boolean_dispatch( - arg_name='return_indices', - arg_index=2, - default=False, - if_true=adaptive_max_pool3d_with_indices, - if_false=_adaptive_max_pool3d) + ret = torch._C._nn.adaptive_max_pool3d(input, output_size) + return ret if return_indices else ret[0] adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r""" -- 2.7.4
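
Context for review: the reverted helper let a single public name such as F.max_pool1d route to one of two weak-script implementations based on a boolean flag that TorchScript requires to be a compile-time constant. Below is a minimal, self-contained sketch of that dispatch pattern in plain Python; the toy functions and the simplified flag handling are illustrative only, not the PyTorch API.

def boolean_dispatch_sketch(arg_name, arg_index, default, if_true, if_false):
    # Pick if_true/if_false based on a boolean argument. The removed PyTorch
    # helper additionally required both targets to be weak-script functions
    # and recorded the wrapper in _boolean_dispatched so the script compiler
    # could resolve the branch at compile time instead of at runtime.
    def fn(*args, **kwargs):
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        else:
            flag = default
        return if_true(*args, **kwargs) if flag else if_false(*args, **kwargs)
    return fn

def _argmax_with_indices(xs, return_indices=True):
    # toy stand-in for max_pool1d_with_indices: returns (value, index)
    idx, val = max(enumerate(xs), key=lambda p: p[1])
    return val, idx

def _argmax_only(xs, return_indices=False):
    return _argmax_with_indices(xs)[0]

my_max = boolean_dispatch_sketch(
    arg_name="return_indices", arg_index=1, default=False,
    if_true=_argmax_with_indices, if_false=_argmax_only)

assert my_max([3, 1, 4]) == 4
assert my_max([3, 1, 4], return_indices=True) == (4, 2)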
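After the revert, the pooling wrappers in torch/nn/functional.py return to the eager-mode idiom visible in the added lines above: compute the (output, indices) pair, then conditionally drop the indices. The trade-off is that `ret if return_indices else ret[0]` has a return type that depends on a runtime value (Tensor vs Tuple[Tensor, Tensor]), which TorchScript's static typing cannot express; that is what the deleted test_bool_dispatch cases exercised. A hedged sketch of the restored shape, with a stub standing in for the ATen kernel (the stub and its name are assumptions for illustration):

import torch

def _pool_with_indices_stub(x):
    # stand-in for e.g. torch.max_pool1d_with_indices: returns (values, indices)
    return x.max(dim=-1, keepdim=True)

def max_pool_like(x, return_indices=False):
    # Tensor when return_indices is False, (Tensor, Tensor) when True
    ret = _pool_with_indices_stub(x)
    return ret if return_indices else ret[0]

x = torch.randn(2, 3, 8)
out = max_pool_like(x)                             # values only
out, idx = max_pool_like(x, return_indices=True)   # values and indices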