Optional ScalarType support for native functions & JIT (#15154)
authorTugrul Ates <ates@fb.com>
Wed, 19 Dec 2018 18:40:48 +0000 (10:40 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 18:45:35 +0000 (10:45 -0800)
Summary:
For #6593 and #9515

This completes the support for optional<ScalarType> in native, JIT and autograd.

Note: Mostly following the existing implementation for optional<Scalar> that was added in https://github.com/pytorch/pytorch/pull/12582.

This PR introduces a way to make functions accept an optional dtype and it will unblock #9515 by allowing the `dtype` param for type promotion interface:
```
func: name(inputs, *, ScalarType? dtype=None, Casting casting=same_kind)
```

An alternative approach could have been using `ScalarType::Undefined` for the same purpose but without optional, though it would have been a bit hacky.
```
func: name(inputs, *, ScalarType dtype=Undefined, Casting casting=same_kind)
```

Here's an example use of this in action: https://github.com/pytorch/pytorch/pull/15133/commits/971f69eac69101955ed90078b44dab975d37a4f7

There are already a bunch of native functions that were getting optional `dtype` through function overloading. https://github.com/pytorch/pytorch/pull/15133 is the attempt to migrate all of those. I will send those changes separately after this since some functions (e.g. sum) need quite a bit of change in the codebase. See the commits over there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15154

Differential Revision: D13457760

Pulled By: tugrulates

fbshipit-source-id: 706134f0bd578683edd416b96329b49a1ba8ab48

tools/autograd/gen_python_functions.py
tools/jit/gen_jit_dispatch.py
tools/jit/templates/register_aten_ops.cpp
torch/csrc/jit/passes/to_batch.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h
torch/csrc/utils/python_arg_parser.cpp

index 1a0fc7f..b1e794e 100644 (file)
@@ -251,7 +251,6 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         'const Type &': 'scalartype',
         'const THPLayout &': 'layout',
         'const Device &': 'device',
-        'optional<ScalarType>': 'scalartypeOptional',
         'int64_t': 'toInt64',
         'bool': 'toBool',
         'double': 'toDouble',
index 352d60b..66371ad 100644 (file)
@@ -54,6 +54,7 @@ TYPE_MAP = {
     'Layout': 'Layout',
     'Device': 'Device',
     'ScalarType': 'ScalarType',
+    'ScalarType?': 'ScalarType?',
     'int64_t': 'int',
     'double': 'float',
     'bool': 'bool',
@@ -84,6 +85,7 @@ FROM_IVALUE = {
     'Scalar': '{}.toScalar()',
     'Scalar?': '{}.toOptional<Scalar>()',
     'ScalarType': '{}.toScalarType()',
+    'ScalarType?': '{}.toOptional<ScalarType>()',
     'Tensor': '{}.toTensor()',
     'TensorList': '{}.toTensorList()->elements()',
     'bool': '{}.toBool()',
index 69d724d..b6ca3d6 100644 (file)
@@ -37,6 +37,7 @@ namespace torch { namespace jit {
 using autograd::Variable;
 using autograd::variable_list;
 using at::Scalar;
+using at::ScalarType;
 using at::Tensor;
 using at::TensorOptions;
 using at::DeviceGuard;
index 0411569..2451304 100644 (file)
@@ -483,6 +483,7 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
     else if(n->kind().is_prim()){
       switch(n->kind()){
         case prim::Constant:
+        case prim::None:
           visitConstant(n, block, res_block);
           break;
         case prim::NumToTensor:
index f86ae6d..f2fc0c0 100644 (file)
@@ -107,6 +107,17 @@ void addInputs(Node *n, const char * name, at::Layout value) {
 void addInputs(Node *n, const char * name, at::ScalarType value) {
   detail::genericAddInput(n, static_cast<int64_t>(value));
 }
+void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value)  {
+  if(value) {
+    detail::genericAddInput(n, static_cast<int64_t>(*value));
+  } else {
+    Graph * g = n->owningGraph();
+    Value* none =
+        g->insertNode(g->createNone(IntType::get()))
+            ->output();
+    n->addInput(none);
+  }
+}
 
 void addInputs(Node *n, const char * name, at::TensorList value) {
   Graph *g = n->owningGraph();
index 15de8c7..cc4bb6b 100644 (file)
@@ -214,6 +214,7 @@ TORCH_API void addInputs(Node *n, const char * name, const at::TensorOptions& va
 TORCH_API void addInputs(Node *n, const char * name, at::Device value);
 TORCH_API void addInputs(Node *n, const char * name, at::Layout value);
 TORCH_API void addInputs(Node *n, const char * name, at::ScalarType value);
+TORCH_API void addInputs(Node *n, const char * name, const c10::optional<at::ScalarType>& value);
 TORCH_API void addInputs(Node *n, const char * name, at::Generator * value);
 
 template<size_t N>
index 8a13094..6c3de1e 100644 (file)
@@ -27,7 +27,6 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"Storage", ParameterType::STORAGE},
   {"PyObject*", ParameterType::PYOBJECT},
   {"ScalarType", ParameterType::SCALARTYPE},
-  {"optional<ScalarType>", ParameterType::SCALARTYPE},
   {"Layout", ParameterType::LAYOUT},
   {"Device", ParameterType::DEVICE},
   {"std::string", ParameterType::STRING},