Return namedtuples from torch.* function with multiple return arguments for C++ opera...
authorXiang Gao <qasdfgtyuiop@gmail.com>
Tue, 22 Jan 2019 19:09:18 +0000 (11:09 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 22 Jan 2019 19:12:18 +0000 (11:12 -0800)
commitc5e1b469beab9edc7a0fb0ab9da1132b795de6c3
tree4744459aea393b9d18e83296e54d8de74614bfa1
parent1e19fd941f60d6296dacc11568126ab6e4c0619b
Return namedtuples from torch.* function with multiple return arguments for C++ operators (#15429)

Summary:
Partially fixes: https://github.com/pytorch/pytorch/issues/394

Implementation detail:

Codegen is modified to generate codes that looks like below:
```C++
static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
  }, /*traceable=*/true);

  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  static PyStructSequence_Field fields0[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc0 = {
    "torch.return_types.svd_out", nullptr,
    fields0, 3
  };
  static PyTypeObject type0;
  static bool namedtuple_type_initialized0 = false;
  if (!namedtuple_type_initialized0) {
    PyStructSequence_InitType(&type0, &desc0);
    namedtuple_type_initialized0 = true;
  }
  static PyStructSequence_Field fields1[] = {
    {"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
  };
  static PyStructSequence_Desc desc1 = {
    "torch.return_types.svd", nullptr,
    fields1, 3
  };
  static PyTypeObject type1;
  static bool namedtuple_type_initialized1 = false;
  if (!namedtuple_type_initialized1) {
    PyStructSequence_InitType(&type1, &desc1);
    namedtuple_type_initialized1 = true;
  }
  if (r.idx == 0) {
    if (r.isNone(3)) {
      return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2)));
    } else {
      auto results = r.tensorlist_n<3>(3);
      return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2), results[0], results[1], results[2]));
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```
Types are defined as static member of `THPVariable_${op_name}` functions, and initialized at the first time the function is called.

When parsing function prototypes in `native_functions.yaml`, the parser will set the specified name as `field_name` when see things like `-> (Tensor t1, ...)`. These field names will be the field names of namedtuple. The class of namedtuples will be named `torch.return_types.${op_name}`.

In some python 2, `PyStructSequence` is not a subtype of tuple, so we have to create some functions to check if an object is a tuple or namedtuple for compatibility issue.

Operators in `native_functions.yaml` are changed such that only `max` and `svd` are generated as namedtuple. Tests are added for these two operators to see if the return value works as expected. Docs for these two ops are also updated to explicitly mention the return value is a namedtuple. More ops will be added in later PRs.

There is some issue with Windows build of linker unable to resolve `PyStructSequence_UnnamedField`, and some workaround is added to deal with this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15429

Differential Revision: D13709678

Pulled By: ezyang

fbshipit-source-id: 23a511c9436977098afc49374e9a748b6e30bccf
18 files changed:
aten/src/ATen/function_wrapper.py
aten/src/ATen/native/README.md
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native_parse.py
test/common_methods_invocations.py
test/test_autograd.py
test/test_torch.py
tools/autograd/derivatives.yaml
tools/autograd/gen_autograd.py
tools/autograd/gen_python_functions.py
torch/_six.py
torch/_torch_docs.py
torch/autograd/gradcheck.py
torch/csrc/autograd/utils/wrap_outputs.h
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/python_arg_flatten.cpp
torch/csrc/python_headers.h
torch/csrc/utils/six.h [new file with mode: 0644]