From c5e1b469beab9edc7a0fb0ab9da1132b795de6c3 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Tue, 22 Jan 2019 11:09:18 -0800 Subject: [PATCH] Return namedtuples from torch.* function with multiple return arguments for C++ operators (#15429) Summary: Partially fixes: https://github.com/pytorch/pytorch/issues/394 Implementation detail: Codegen is modified to generate codes that looks like below: ```C++ static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)", }, /*traceable=*/true); ParsedArgs<6> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); static PyStructSequence_Field fields0[] = { {"U", ""}, {"S", ""}, {"V", ""}, {nullptr} }; static PyStructSequence_Desc desc0 = { "torch.return_types.svd_out", nullptr, fields0, 3 }; static PyTypeObject type0; static bool namedtuple_type_initialized0 = false; if (!namedtuple_type_initialized0) { PyStructSequence_InitType(&type0, &desc0); namedtuple_type_initialized0 = true; } static PyStructSequence_Field fields1[] = { {"U", ""}, {"S", ""}, {"V", ""}, {nullptr} }; static PyStructSequence_Desc desc1 = { "torch.return_types.svd", nullptr, fields1, 3 }; static PyTypeObject type1; static bool namedtuple_type_initialized1 = false; if (!namedtuple_type_initialized1) { PyStructSequence_InitType(&type1, &desc1); namedtuple_type_initialized1 = true; } if (r.idx == 0) { if (r.isNone(3)) { return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2))); } else { auto results = r.tensorlist_n<3>(3); return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2), results[0], results[1], results[2])); } } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } ``` Types are defined as static member of `THPVariable_${op_name}` functions, and initialized at the first time the function is called. When parsing function prototypes in `native_functions.yaml`, the parser will set the specified name as `field_name` when see things like `-> (Tensor t1, ...)`. These field names will be the field names of namedtuple. The class of namedtuples will be named `torch.return_types.${op_name}`. In some python 2, `PyStructSequence` is not a subtype of tuple, so we have to create some functions to check if an object is a tuple or namedtuple for compatibility issue. Operators in `native_functions.yaml` are changed such that only `max` and `svd` are generated as namedtuple. Tests are added for these two operators to see if the return value works as expected. Docs for these two ops are also updated to explicitly mention the return value is a namedtuple. More ops will be added in later PRs. There is some issue with Windows build of linker unable to resolve `PyStructSequence_UnnamedField`, and some workaround is added to deal with this case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15429 Differential Revision: D13709678 Pulled By: ezyang fbshipit-source-id: 23a511c9436977098afc49374e9a748b6e30bccf --- aten/src/ATen/function_wrapper.py | 27 ++++++++++++++ aten/src/ATen/native/README.md | 11 ++++-- aten/src/ATen/native/native_functions.yaml | 37 ++++++++++++------- aten/src/ATen/native_parse.py | 13 +++++++ test/common_methods_invocations.py | 4 +-- test/test_autograd.py | 4 +-- test/test_torch.py | 30 ++++++++++++++-- tools/autograd/derivatives.yaml | 14 ++++---- tools/autograd/gen_autograd.py | 4 ++- tools/autograd/gen_python_functions.py | 57 ++++++++++++++++++++++++++++-- torch/_six.py | 15 ++++++++ torch/_torch_docs.py | 14 ++++---- torch/autograd/gradcheck.py | 4 +-- torch/csrc/autograd/utils/wrap_outputs.h | 17 +++++++++ torch/csrc/jit/pybind_utils.h | 3 +- torch/csrc/jit/python_arg_flatten.cpp | 3 +- torch/csrc/python_headers.h | 1 + torch/csrc/utils/six.h | 21 +++++++++++ 18 files changed, 238 insertions(+), 41 deletions(-) create mode 100644 torch/csrc/utils/six.h diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 6139bd5..4d5d0df 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -425,8 +425,27 @@ AtFormal = TypedDict('AtFormal', { 'size': int, }, total=False) +# Note [field_name versus name] +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# What is the difference between "field_name" and "name"? +# +# Return values of ATen operators always have a name: if it is not +# explicitly assigned a name inside native_functions.yaml like func: +# myop() -> (Tensor indices, Tensor value), then the codegen will +# automatically assign it a name like result0, or name might be +# specified inside Declarations.cwrap. We don't want these assigned +# names to become part of the public API when we return a namedtuple for +# any such multiple-return function. +# +# Thus field_name is like name, but it is defined only when there is a +# name specified in native_functions.yaml. If field_name is defined, +# then the codegen would generate code to return namedtuple. Otherwise, +# it would just return tuple. + ReturnType = TypedDict('ReturnType', { 'name': str, + # See Note [field_name versus name] + 'field_name': str, 'type': str, 'dynamic_type': str, }, total=False) @@ -465,6 +484,8 @@ FunctionOption = TypedDict('FunctionOption', { 'with_gil': bool, 'cpu_half': bool, 'deprecated': bool, + # See Note [field_name versus name] + 'field_name': str, 'formals_list': List[AtFormal], 'formals_with_defaults': List[str], 'formals': List[str], @@ -973,6 +994,8 @@ def create_generic(top_env, declarations): return_types = [] # List[ReturnType] for t_raw in ret: + # See Note [field_name versus name] + field_name = None if isinstance(t_raw, string_type): t = t_raw name = None @@ -982,6 +1005,8 @@ def create_generic(top_env, declarations): else: t = t_raw['type'] name = t_raw['name'] + if 'field_name' in t_raw: + field_name = t_raw['field_name'] # can't actually return a TensorList (since it's a reference object) actual_return_type = {'TensorList': 'std::vector'}.get(t, t) @@ -996,6 +1021,8 @@ def create_generic(top_env, declarations): } # type: ReturnType if name is not None: rtype['name'] = name + if field_name is not None: + rtype['field_name'] = field_name return_types.append(rtype) return return_types diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 3c2bb42..e510fc6 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -152,8 +152,15 @@ only simple, universal types, as well as a handful of fundamental Tensor structu (e.g., `Tensor` and `Generator*`), because these types can be easily ported to any language bound to ATen (in practice, C++ and Python.) -Return also supports specifying (optional) return argument names; these are useful for writing -derivatives in terms of return arguments in `tools/autograd/derivatives.yaml`. +Return also supports specifying (optional) return argument names. These serve +two functions: + +- They let you easily write derivatives in terms of return arguments in + `tools/autograd/derivatives.yaml` + +- They correspond to the named field the output can be referred to from + Python. (This means that changing a return argument name is + BC-breaking, be careful!) Note that argument type modifiers such as defaults and optional are not currently supported on Return. diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 63e2006..f4e51c1 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -115,6 +115,7 @@ - func: adaptive_avg_pool1d(Tensor self, IntList[1] output_size) -> Tensor +# Return: (Tensor output, Tensor indices) - func: adaptive_max_pool1d(Tensor self, IntList[1] output_size) -> (Tensor, Tensor) - func: add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor @@ -1142,14 +1143,15 @@ - func: matrix_power(Tensor self, int64_t n) -> Tensor variants: function, method -- func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) +- func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor values, Tensor indices) variants: function, method -- func: max_out(Tensor max, Tensor max_values, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) +- func: max_out(Tensor max, Tensor max_values, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor values, Tensor indices) - func: max_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor variants: function, method +# Return: (Tensor output, Tensor indices) - func: max_pool1d_with_indices(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor) - func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> Tensor @@ -3022,9 +3024,9 @@ - func: eig(Tensor self, bool eigenvectors=false) -> (Tensor, Tensor) variants: method, function -- func: svd_out(Tensor U, Tensor S, Tensor V, Tensor self, bool some=true, bool compute_uv=true) -> (Tensor, Tensor, Tensor) +- func: svd_out(Tensor U, Tensor S, Tensor V, Tensor self, bool some=true, bool compute_uv=true) -> (Tensor U, Tensor S, Tensor V) -- func: svd(Tensor self, bool some=true, bool compute_uv=true) -> (Tensor, Tensor, Tensor) +- func: svd(Tensor self, bool some=true, bool compute_uv=true) -> (Tensor U, Tensor S, Tensor V) variants: method, function - func: cholesky_out(Tensor result, Tensor self, bool upper=false) -> Tensor @@ -3576,10 +3578,12 @@ matches_jit_signature: True python_module: nn -- func: adaptive_max_pool2d_out(Tensor output, Tensor indices, Tensor self, IntList[2] output_size) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool2d_out(Tensor output, Tensor indices, Tensor self, IntList[2] output_size) -> (Tensor, Tensor) python_module: nn -- func: adaptive_max_pool2d(Tensor self, IntList[2] output_size) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool2d(Tensor self, IntList[2] output_size) -> (Tensor, Tensor) python_module: nn - func: adaptive_max_pool2d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, Tensor indices) -> Tensor @@ -3589,10 +3593,12 @@ matches_jit_signature: True python_module: nn -- func: adaptive_max_pool3d_out(Tensor output, Tensor indices, Tensor self, IntList[3] output_size) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool3d_out(Tensor output, Tensor indices, Tensor self, IntList[3] output_size) -> (Tensor, Tensor) python_module: nn -- func: adaptive_max_pool3d(Tensor self, IntList[3] output_size) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: adaptive_max_pool3d(Tensor self, IntList[3] output_size) -> (Tensor, Tensor) python_module: nn - func: adaptive_max_pool3d_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, Tensor indices) -> Tensor @@ -3626,13 +3632,15 @@ - func: avg_pool3d_backward(Tensor grad_output, Tensor self, IntList[3] kernel_size, IntList[3] stride, IntList[3] padding, bool ceil_mode, bool count_include_pad) -> Tensor python_module: nn -- func: fractional_max_pool2d_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: fractional_max_pool2d_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor, Tensor) python_module: nn dispatch: CPU: fractional_max_pool2d_out_cpu CUDA: fractional_max_pool2d_out_cuda -- func: fractional_max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: fractional_max_pool2d(Tensor self, IntList[2] kernel_size, IntList[2] output_size, Tensor random_samples) -> (Tensor, Tensor) python_module: nn dispatch: CPU: fractional_max_pool2d_cpu @@ -3677,7 +3685,8 @@ - func: max_pool2d_with_indices_out(Tensor output, Tensor indices, Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices) python_module: nn -- func: max_pool2d_with_indices(Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: max_pool2d_with_indices(Tensor self, IntList[2] kernel_size, IntList[2] stride={}, IntList[2] padding=0, IntList[2] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor) python_module: nn - func: max_pool2d_with_indices_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[2] kernel_size, IntList[2] stride, IntList[2] padding, IntList[2] dilation, bool ceil_mode, Tensor indices) -> Tensor @@ -3686,10 +3695,12 @@ - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, IntList[2] kernel_size, IntList[2] stride, IntList[2] padding, IntList[2] dilation, bool ceil_mode, Tensor indices) -> Tensor python_module: nn -- func: max_pool3d_with_indices_out(Tensor output, Tensor indices, Tensor self, IntList[3] kernel_size, IntList[3] stride={}, IntList[3] padding=0, IntList[3] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: max_pool3d_with_indices_out(Tensor output, Tensor indices, Tensor self, IntList[3] kernel_size, IntList[3] stride={}, IntList[3] padding=0, IntList[3] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor) python_module: nn -- func: max_pool3d_with_indices(Tensor self, IntList[3] kernel_size, IntList[3] stride={}, IntList[3] padding=0, IntList[3] dilation=1, bool ceil_mode=false) -> (Tensor output, Tensor indices) +# Return: (Tensor output, Tensor indices) +- func: max_pool3d_with_indices(Tensor self, IntList[3] kernel_size, IntList[3] stride={}, IntList[3] padding=0, IntList[3] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor) python_module: nn - func: max_pool3d_with_indices_backward_out(Tensor grad_input, Tensor grad_output, Tensor self, IntList[3] kernel_size, IntList[3] stride, IntList[3] padding, IntList[3] dilation, bool ceil_mode, Tensor indices) -> Tensor diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py index eb6b0c2..4008c94 100644 --- a/aten/src/ATen/native_parse.py +++ b/aten/src/ATen/native_parse.py @@ -105,6 +105,8 @@ def parse_return_arguments(return_decl, inplace): for arg_idx, arg in enumerate(return_decl.split(', ')): type_and_maybe_name = [a.strip() for a in arg.rsplit(' ', 1)] + # See Note [field_name versus name] + field_name = None if len(type_and_maybe_name) == 1: t = type_and_maybe_name[0] if inplace: @@ -113,9 +115,12 @@ def parse_return_arguments(return_decl, inplace): name = 'result' if not multiple_args else 'result' + str(arg_idx) else: t, name = type_and_maybe_name + field_name = name typ = sanitize_type(t) argument_dict = {'type': typ, 'name': name} + if field_name is not None: + argument_dict['field_name'] = field_name argument_dict['output'] = True arguments.append(argument_dict) @@ -134,6 +139,13 @@ def parse_native_yaml(path): return yaml.load(f, Loader=Loader) +def propagate_field_names(output_arguments, return_arguments): + if output_arguments: + for i, r in enumerate(return_arguments): + if 'field_name' in r: + output_arguments[i]['field_name'] = r['field_name'] + + def run(paths): declarations = [] for path in paths: @@ -152,6 +164,7 @@ def run(paths): return_arguments = parse_return_arguments(return_decl, declaration['inplace']) arguments = parse_arguments(arguments, func, declaration['name'], return_arguments) output_arguments = [x for x in arguments if x.get('output')] + propagate_field_names(output_arguments, return_arguments) declaration['return'] = return_arguments if len(output_arguments) == 0 else output_arguments declaration['variants'] = func.get('variants', ['function']) declaration['requires_tensor'] = func.get('requires_tensor', False) diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 4147dab..41c8763 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -1,5 +1,5 @@ import torch -from torch._six import inf, nan +from torch._six import inf, nan, istuple from functools import reduce, wraps from operator import mul, itemgetter from torch.autograd import Variable, Function, detect_anomaly @@ -988,7 +988,7 @@ def run_additional_tri_tests(self, device): def unpack_variables(args): - if isinstance(args, tuple): + if istuple(args): return tuple(unpack_variables(elem) for elem in args) else: return args diff --git a/test/test_autograd.py b/test/test_autograd.py index 4a8d951..6525870 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -10,7 +10,7 @@ from collections import OrderedDict from itertools import product from operator import mul, itemgetter from functools import reduce, wraps -from torch._six import inf, nan +from torch._six import inf, nan, istuple from torch.autograd.gradcheck import gradgradcheck, gradcheck from torch.autograd.function import once_differentiable from torch.autograd.profiler import profile @@ -2898,7 +2898,7 @@ def add_test( output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable) if not exclude_tensor_method(name, test_name): output_tensor = getattr(self_tensor, name)(*args_tensor, **kwargs_variable) - if not isinstance(output_tensor, torch.Tensor) and not isinstance(output_tensor, tuple): + if not isinstance(output_tensor, torch.Tensor) and not istuple(output_tensor): output_tensor = torch.DoubleTensor((output_tensor,)) self.assertEqual(unpack_variables(output_variable), output_tensor) # TODO: check that both have changed after adding all inplace ops diff --git a/test/test_torch.py b/test/test_torch.py index 0836643..4522317 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -18,7 +18,7 @@ import re from torch._utils_internal import get_file_path, get_file_path_2 from torch.utils.dlpack import from_dlpack, to_dlpack from torch._utils import _rebuild_tensor -from torch._six import inf, nan, string_classes +from torch._six import inf, nan, string_classes, istuple from itertools import product, combinations, combinations_with_replacement from functools import reduce from torch import multiprocessing as mp @@ -1062,7 +1062,7 @@ class _TestTorchMixin(object): def fn(x, dim, keepdim=False, out=None): ans = fn_attr(x, dim, keepdim=keepdim, out=out) - return ans if not isinstance(ans, tuple) else ans[0] + return ans if not istuple(ans) else ans[0] def fn_tuple(x, dim, keepdim=False, out=None): return fn_attr(x, dim, keepdim=keepdim, out=out) @@ -6910,6 +6910,32 @@ class _TestTorchMixin(object): val2 = rec.select(-1, -1).data.abs()[0][0][0].sum() self.assertEqual(val1, val2, 1e-8, 'absolute value') + def test_namedtuple_return(self): + a = torch.randn(5, 5) + + # test max + ret = a.max(dim=0) + self.assertEqual(ret.values, ret[0]) + self.assertEqual(ret.indices, ret[1]) + ret1 = torch.max(a, dim=0, out=tuple(ret)) + self.assertEqual(ret1.values, ret1[0]) + self.assertEqual(ret1.indices, ret1[1]) + self.assertEqual(ret1.values, ret[0]) + self.assertEqual(ret1.indices, ret[1]) + + # test svd + ret = a.svd() + self.assertEqual(ret.U, ret[0]) + self.assertEqual(ret.S, ret[1]) + self.assertEqual(ret.V, ret[2]) + ret1 = torch.svd(a, out=tuple(ret)) + self.assertEqual(ret1.U, ret1[0]) + self.assertEqual(ret1.S, ret1[1]) + self.assertEqual(ret1.V, ret1[2]) + self.assertEqual(ret1.U, ret[0]) + self.assertEqual(ret1.S, ret[1]) + self.assertEqual(ret1.V, ret[2]) + def test_hardshrink(self): data_original = torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2) float_types = [ diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9ae87a7..621496d 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -479,7 +479,7 @@ self: zeros_like(self.expand(at::infer_size(self.sizes(), mask.sizes()))).masked_scatter_(mask, grad) - name: max(Tensor self, int64_t dim, bool keepdim) - self: index_select_backward(grad, dim, result1, self.sizes(), keepdim) + self: index_select_backward(grad, dim, indices, self.sizes(), keepdim) - name: max(Tensor self) self: select_equals_backward(grad, self, result) @@ -791,7 +791,7 @@ self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) - name: svd(Tensor self, bool some, bool compute_uv) - self: svd_backward(grads, self, some, compute_uv, result0, result1, result2) + self: svd_backward(grads, self, some, compute_uv, U, S, V) - name: symeig(Tensor self, bool eigenvectors, bool upper) self: symeig_backward(grads, self, eigenvectors, upper, result0, result1) @@ -1067,10 +1067,10 @@ self: adaptive_avg_pool3d_backward(grad, self) - name: adaptive_max_pool2d(Tensor self, IntList output_size) - self: adaptive_max_pool2d_backward(grad, self, indices) + self: adaptive_max_pool2d_backward(grad, self, result1) - name: adaptive_max_pool3d(Tensor self, IntList output_size) - self: adaptive_max_pool3d_backward(grad, self, indices) + self: adaptive_max_pool3d_backward(grad, self, result1) - name: avg_pool2d(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad) self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad) @@ -1079,16 +1079,16 @@ self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad) - name: fractional_max_pool2d(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples) - self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, indices) + self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1) - name: fractional_max_pool3d(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples) self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, indices) - name: max_pool2d_with_indices(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) - self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices) + self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) - name: max_pool3d_with_indices(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode) - self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices) + self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) - name: max_unpool2d(Tensor self, Tensor indices, IntList output_size) self: max_unpool2d_backward(grad, self, indices, output_size) diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 64c1285..5789fa3 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -223,8 +223,10 @@ def main(): help='path to Declarations.yaml') parser.add_argument('out', metavar='OUT', help='path to output directory') + parser.add_argument('autograd', metavar='AUTOGRAD', + help='path to autograd directory') args = parser.parse_args() - gen_autograd(args.declarations, args.out) + gen_autograd(args.declarations, args.out, args.autograd) if __name__ == '__main__': diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 4794b02..d837dd7 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -53,6 +53,7 @@ static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) ${unpack_self} ParsedArgs<${max_args}> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); + ${declare_namedtuple_return_types} ${dispatch} Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -63,8 +64,9 @@ PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\ static PyObject * ${pycname}(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS + ${declare_namedtuple_return_types} ${unpack_self} - return wrap(${dispatch_name}(${actuals})); + return wrap(${namedtuple_return_type}${dispatch_name}(${actuals})); END_HANDLE_TH_ERRORS } """) @@ -100,7 +102,7 @@ PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\ ${call_dispatch}.set_requires_grad(${requires_grad})""") PY_VARIABLE_WRAP = CodeTemplate("""\ -return wrap(${call_dispatch});""") +return wrap(${namedtuple_return_type}${call_dispatch});""") PY_VARIABLE_DISPATCH = CodeTemplate("""\ inline ${simple_return_type} ${dispatch_name}(${formal_args}) { @@ -113,6 +115,22 @@ inline ${simple_return_type} ${dispatch_name}(${formal_args}) { PY_VARIABLE_METHOD_DEF = CodeTemplate("""\ {"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""") +PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\ +static PyStructSequence_Field fields${namedtuple_type_index}[] = { + ${namedtuple_fields} {nullptr} +}; +static PyStructSequence_Desc desc${namedtuple_type_index} = { + "torch.return_types.${name}", nullptr, + fields${namedtuple_type_index}, ${namedtuple_size} +}; +static PyTypeObject type${namedtuple_type_index}; +static bool namedtuple_type_initialized${namedtuple_type_index} = false; +if (!namedtuple_type_initialized${namedtuple_type_index}) { + PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index}); + namedtuple_type_initialized${namedtuple_type_index} = true; +} +""") + UNPACK_SELF = "auto& self = reinterpret_cast(self_)->cdata;" PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\ @@ -589,6 +607,32 @@ def create_python_bindings(python_functions, has_self, is_module=False): python_binding_arguments.append(requires_grad_arg) return python_binding_arguments + def emit_namedtuple_return_type_def(declaration, next_index): + returns = declaration['returns'] + if len(returns) <= 1 or all(['field_name' not in x for x in returns]): + declaration['namedtuple_return_type'] = '' + return '', next_index + declaration['namedtuple_type_index'] = next_index + declaration['namedtuple_fields'] = '' + for x in returns: + # See Note [field_name versus name] + if 'field_name' not in x: + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in namedtuple; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + else: + declaration['namedtuple_fields'] += '{"' + x['field_name'] + '", ""}, ' + declaration['namedtuple_size'] = len(returns) + declaration['namedtuple_return_type'] = '&type{}, '.format(next_index) + return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1 + def process_function(name, declarations): for declaration in declarations: declaration['python_binding_arguments'] = get_python_binding_arguments(declaration) @@ -601,11 +645,19 @@ def create_python_bindings(python_functions, has_self, is_module=False): 'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations), 'unpack_self': [], 'dispatch': [], + 'declare_namedtuple_return_types': '', } if has_self: env['unpack_self'] = [UNPACK_SELF] + # generate namedtuple type declare + next_index = 0 + for declaration in declarations: + typedef, next_index = emit_namedtuple_return_type_def(declaration, next_index) + env['declare_namedtuple_return_types'] += typedef + + # emit dispatch grouped = group_declarations(declarations) for i, dictionary in enumerate(grouped): signature = dictionary['signature'] @@ -629,6 +681,7 @@ def create_python_bindings(python_functions, has_self, is_module=False): tmpl = PY_VARIABLE_METHOD_NOARGS env['actuals'] = ['self'] env['flags'] = 'METH_NOARGS' + env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type'] else: tmpl = PY_VARIABLE_METHOD_VARARGS env['flags'] = 'METH_VARARGS | METH_KEYWORDS' diff --git a/torch/_six.py b/torch/_six.py index ad50cf2..f6bdd39 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -135,3 +135,18 @@ if PY2: import __builtin__ as builtins elif PY3: import builtins + + +# The codes below is not copied from the six package, so the copyright +# declaration at the beginning does not apply. +# +# Copyright(c) PyTorch contributors +# + +def istuple(obj): + # Usually instances of PyStructSequence is also an instance of tuple + # but in some py2 environment it is not, so we have to manually check + # the name of the type to determine if it is a namedtupled returned + # by a pytorch operator. + t = type(obj) + return isinstance(obj, tuple) or t.__module__ == 'torch.return_types' diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 8d55062..28e60e2 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2649,9 +2649,10 @@ Example:: .. function:: max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor) -Returns the maximum value of each row of the :attr:`input` tensor in the given -dimension :attr:`dim`. The second return value is the index location of each -maximum value found (argmax). +Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum +value of each row of the :attr:`input` tensor in the given dimension +:attr:`dim`. And ``indices`` is the index location of each maximum value found +(argmax). If :attr:`keepdim` is ``True``, the output tensors are of the same size as :attr:`input` except in the dimension :attr:`dim` where they are of size 1. @@ -2673,7 +2674,7 @@ Example:: [ 1.5717, -0.9207, 0.1297, -1.8768], [-0.6172, 1.0036, -0.6060, -0.2432]]) >>> torch.max(a, 1) - (tensor([ 0.8475, 1.1949, 1.5717, 1.0036]), tensor([ 3, 0, 0, 1])) + torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1])) .. function:: max(input, other, out=None) -> Tensor @@ -4541,8 +4542,9 @@ add_docstr(torch.svd, r""" svd(input, some=True, compute_uv=True, out=None) -> (Tensor, Tensor, Tensor) -`U, S, V = torch.svd(A)` returns the singular value decomposition of a -real matrix `A` of size `(n x m)` such that :math:`A = USV^T`. +``svd(A)`` returns a namedtuple ``(U, S, V)`` which the singular value +decomposition of a input real matrix `A` of size `(n x m)` such that +:math:`A = USV^T`. `U` is of shape :math:`(n \times n)`. diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index dbb6cad..22cd8e2 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1,5 +1,5 @@ import torch -from torch._six import container_abcs +from torch._six import container_abcs, istuple import torch.testing import sys from itertools import product @@ -150,7 +150,7 @@ def get_analytical_jacobian(input, output): def _as_tuple(x): - if isinstance(x, tuple): + if istuple(x): return x elif isinstance(x, list): return tuple(x) diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index cc77775..9543159 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -26,6 +26,14 @@ inline PyObject* wrap(std::tuple tensors) { return r.release(); } +inline PyObject* wrap(PyTypeObject *type, std::tuple tensors) { + auto r = THPObjectPtr{PyStructSequence_New(type)}; + if (!r) throw python_error(); + PyStructSequence_SET_ITEM(r.get(), 0, wrap(std::get<0>(tensors))); + PyStructSequence_SET_ITEM(r.get(), 1, wrap(std::get<1>(tensors))); + return r.release(); +} + inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(3)}; if (!r) throw python_error(); @@ -35,6 +43,15 @@ inline PyObject* wrap(std::tuple tensors) { return r.release(); } +inline PyObject* wrap(PyTypeObject *type, std::tuple tensors) { + auto r = THPObjectPtr{PyStructSequence_New(type)}; + if (!r) throw python_error(); + PyStructSequence_SET_ITEM(r.get(), 0, wrap(std::get<0>(tensors))); + PyStructSequence_SET_ITEM(r.get(), 1, wrap(std::get<1>(tensors))); + PyStructSequence_SET_ITEM(r.get(), 2, wrap(std::get<2>(tensors))); + return r.release(); +} + inline PyObject* wrap(std::tuple tensors) { auto r = THPObjectPtr{PyTuple_New(4)}; if (!r) throw python_error(); diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index f18f954..9e8f9af 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -79,7 +80,7 @@ inline IValue toIValue(py::handle input) { AT_ERROR("sparse tensors not supported"); } return ten; - } else if (py::isinstance(input)) { + } else if (six::isTuple(input)) { py::tuple input_tuple = py::cast(input); Stack s; s.reserve(input_tuple.size()); diff --git a/torch/csrc/jit/python_arg_flatten.cpp b/torch/csrc/jit/python_arg_flatten.cpp index 95f7770..4061fc6 100644 --- a/torch/csrc/jit/python_arg_flatten.cpp +++ b/torch/csrc/jit/python_arg_flatten.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -31,7 +32,7 @@ py::object cast_handle_sequence(std::vector objs) { void flatten_rec(PyObject* obj, ParsedArgs& args) { auto& structure = args.desc.structure; - if (PyTuple_Check(obj)) { + if (six::isTuple(obj)) { structure.push_back(D::TupleOpen); for (auto item : py::reinterpret_borrow(obj)) flatten_rec(item.ptr(), args); diff --git a/torch/csrc/python_headers.h b/torch/csrc/python_headers.h index 390ce97..d1371d4 100644 --- a/torch/csrc/python_headers.h +++ b/torch/csrc/python_headers.h @@ -7,6 +7,7 @@ #undef _POSIX_C_SOURCE #include +#include #pragma pop_macro("_XOPEN_SOURCE") #pragma pop_macro("_POSIX_C_SOURCE") diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h new file mode 100644 index 0000000..0dd2e2f --- /dev/null +++ b/torch/csrc/utils/six.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace six { + +// Usually instances of PyStructSequence is also an instance of tuple +// but in some py2 environment it is not, so we have to manually check +// the name of the type to determine if it is a namedtupled returned +// by a pytorch operator. + +inline bool isTuple(pybind11::handle input) { + std::string m = pybind11::str(input.get_type().attr("__module__")); + return pybind11::isinstance(input) || m == "torch.return_types"; +} + +inline bool isTuple(PyObject* obj) { + return isTuple(pybind11::handle(obj)); +} + +} // namespace six -- 2.7.4