Optional ScalarType support for native functions & JIT (#15154)
author Tugrul Ates <ates@fb.com>
Wed, 19 Dec 2018 18:40:48 +0000 (10:40 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 18:45:35 +0000 (10:45 -0800)
commit 560530aeecdae8c37d47c8ee7967c8583843a0bf
tree 0157bd66f80452b231f2380f9a08bdb5a57f516a
parent 54d4fe3f49d781030cb4c3b7ec3944fb2cf7aae3

Summary:
For #6593 and #9515

This completes support for `optional<ScalarType>` in native functions, the JIT, and autograd.

Note: This mostly follows the existing implementation of `optional<Scalar>`, which was added in https://github.com/pytorch/pytorch/pull/12582.

This PR lets functions accept an optional `dtype`. It will unblock #9515 by allowing a `dtype` parameter in the type promotion interface:
```
func: name(inputs, *, ScalarType? dtype=None, Casting casting=same_kind)
```
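The `?` suffix in the schema line is what marks `dtype` as optional. The sketch below shows how such an argument string can be split into its base type, optional flag, name, and default. It is a minimal, hypothetical parser: `Argument`, `parse_argument`, and `cpp_type` are illustrative names, not the code in `tools/jit/gen_jit_dispatch.py`. The `optional<...>` spelling follows this message's own shorthand, not an exact C++ namespace.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class Argument:
    type_name: str          # base type, e.g. "ScalarType"
    is_optional: bool       # True when the type carries a trailing "?"
    name: str               # argument name, e.g. "dtype"
    default: Optional[str]  # default value as written in the schema, if any


# Matches strings such as "ScalarType? dtype=None" or "Casting casting=same_kind".
_ARG_RE = re.compile(
    r"^(?P<type>\w+)(?P<opt>\?)?\s+(?P<name>\w+)(?:=(?P<default>\S+))?$"
)


def parse_argument(text: str) -> Argument:
    """Parse one schema argument (hypothetical helper, for illustration only)."""
    m = _ARG_RE.match(text.strip())
    if m is None:
        raise ValueError(f"cannot parse argument: {text!r}")
    return Argument(
        type_name=m.group("type"),
        is_optional=m.group("opt") is not None,
        name=m.group("name"),
        default=m.group("default"),
    )


def cpp_type(arg: Argument) -> str:
    """Map a parsed argument to a C++-style type name using the optional<...> shorthand."""
    return f"optional<{arg.type_name}>" if arg.is_optional else arg.type_name
```

With this sketch, `parse_argument("ScalarType? dtype=None")` produces an optional `ScalarType` argument named `dtype` whose default is `None`. `Casting casting=same_kind` parses as a plain, non-optional argument.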

An alternative would have been to use `ScalarType::Undefined` as a sentinel value instead of an optional, though that would have been somewhat hacky:
```
func: name(inputs, *, ScalarType dtype=Undefined, Casting casting=same_kind)
```
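The sketch below illustrates why the sentinel approach is hacky. `ScalarType` here is a hypothetical Python stand-in, not the ATen enum, and both resolution functions are illustrative names. The sentinel is an ordinary member of the enum, so "no dtype given" is only distinguishable from a real dtype by convention. Any code that enumerates dtypes must also filter it out by hand. With `Optional`, absence is `None` and never collides with a real value.

```python
from enum import Enum
from typing import List, Optional


class ScalarType(Enum):
    Float = 1
    Double = 2
    Long = 3
    Undefined = 4  # sentinel meaning "no dtype given" (used only by the sentinel design)


def out_dtype_sentinel(input_dtype: ScalarType,
                       dtype: ScalarType = ScalarType.Undefined) -> ScalarType:
    # "Absent" is just another enum member; callers must remember to check for it.
    return input_dtype if dtype is ScalarType.Undefined else dtype


def out_dtype_optional(input_dtype: ScalarType,
                       dtype: Optional[ScalarType] = None) -> ScalarType:
    # "Absent" is None, which cannot be mistaken for a real dtype.
    return input_dtype if dtype is None else dtype


def real_dtypes() -> List[ScalarType]:
    # With a sentinel in the enum, listing the real dtypes needs an explicit filter.
    return [t for t in ScalarType if t is not ScalarType.Undefined]
```

Both functions behave the same for well-formed calls. In the sentinel design, though, an `Undefined` input dtype can silently flow through as if it were a result type.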

Here is an example of this in use: https://github.com/pytorch/pytorch/pull/15133/commits/971f69eac69101955ed90078b44dab975d37a4f7

Several native functions already accept an optional `dtype` through function overloading. https://github.com/pytorch/pytorch/pull/15133 attempts to migrate all of them. I will send those changes separately after this PR, since some functions (e.g. `sum`) require substantial changes to the codebase. See the commits in that PR.
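The migration from overloads to an optional parameter can be pictured with the toy sketch below. It is loosely modeled on a `sum`-like reduction in which a given dtype casts each element before reducing. Python types such as `int` and `float` stand in for `ScalarType`. All function names are hypothetical and do not correspond to ATen signatures. The two overload-style entry points collapse into one function whose `dtype=None` default selects the original behavior.

```python
from typing import Callable, List, Optional, Union

Number = Union[int, float]
Cast = Callable[[Number], Number]  # stand-in for a dtype: a callable such as int or float


# Overload style: two separate entry points for the same operation.
def sum_default(values: List[Number]) -> Number:
    return sum(values)


def sum_with_dtype(values: List[Number], dtype: Cast) -> Number:
    # Cast every element to the requested type before reducing.
    return sum(dtype(v) for v in values)


# Optional style: a single entry point; None means "keep the input type".
def sum_optional(values: List[Number], dtype: Optional[Cast] = None) -> Number:
    cast_values = values if dtype is None else [dtype(v) for v in values]
    return sum(cast_values)
```

For example, `sum_optional([1, 2, 3])` keeps integer arithmetic, while `sum_optional([1, 2, 3], float)` yields a float. This matches the two overloads without a second entry point.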
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15154

Differential Revision: D13457760

Pulled By: tugrulates

fbshipit-source-id: 706134f0bd578683edd416b96329b49a1ba8ab48
tools/autograd/gen_python_functions.py
tools/jit/gen_jit_dispatch.py
tools/jit/templates/register_aten_ops.cpp
torch/csrc/jit/passes/to_batch.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h
torch/csrc/utils/python_arg_parser.cpp