Allow passing dicts as trace inputs. (#18092)
authorEric Faust <efaust@fb.com>
Fri, 19 Apr 2019 06:48:59 +0000 (23:48 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 06:52:00 +0000 (23:52 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18092

Previously, tracing required all inputs to be either tensors,
or tuples of tensor. Now, we allow users to pass dicts as well.

Differential Revision: D14491795

fbshipit-source-id: 7a2df218e5d00f898d01fa5b9669f9d674280be3

12 files changed:
aten/src/ATen/core/type.cpp
test/cpp/jit/test_autodiff.h
test/test_jit.py
torch/autograd/function.py
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/python_tracer.cpp
torch/csrc/jit/python_tracer.h
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/init.cpp
torch/csrc/jit/tracer.cpp
torch/csrc/jit/tracer.h
torch/jit/__init__.py

index 9f7ffc3..6e24026 100644 (file)
@@ -287,6 +287,18 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
       }
     }
     return static_cast<TypePtr>(TupleType::create(elements));
+  } else if (t1->cast<DictType>() && t2->cast<DictType>()) {
+    auto dict1 = t1->cast<DictType>();
+    auto dict2 = t2->cast<DictType>();
+
+    auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType());
+    auto unshaped_value1 = unshapedType(dict1->getValueType());
+    auto unshaped_value2 = unshapedType(dict2->getValueType());
+    auto unified_value = tryEitherIsTheSuperType(unshaped_value1, unshaped_value2);
+    if (!unified_key || !unified_value) {
+      return c10::nullopt;
+    }
+    return DictType::create(*unified_key, *unified_value);
   }
 
   return c10::nullopt;
index 00c7624..094da84 100644 (file)
@@ -60,9 +60,17 @@ variable_list get_grad_outputs(const variable_list& vars) {
 std::shared_ptr<Graph> trace(
     const ADTestSpec& test,
     const variable_list& vars_in) {
+  Stack input_vars = fmap<IValue>(vars_in);
+  std::vector<TypePtr> input_types;
+  input_types.reserve(input_vars.size());
+  for (auto i = 0; i < input_vars.size(); i++) {
+    input_types.push_back(TensorType::get());
+  }
+  auto input_typeptr = TupleType::create(std::move(input_types));
   std::shared_ptr<tracer::TracingState> state;
   Stack trace_stack_in;
-  std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in));
+  std::tie(state, trace_stack_in) =
+      tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
   variable_list trace_vars_in = fmap(
       trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
   auto trace_vars_out = test(trace_vars_in);
index b66a156..f20cfda 100644 (file)
@@ -12,6 +12,7 @@ from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
+from torch.autograd.function import _nested_map
 from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2, builtins, StringIO
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
@@ -19,7 +20,7 @@ from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     freeze_rng_state, set_rng_seed, slowTest
 from common_nn import module_tests, new_module_tests, criterion_tests
 from textwrap import dedent
-from functools import wraps
+from functools import wraps, reduce
 import os
 import io
 import sys
@@ -535,9 +536,24 @@ class JitTestCase(TestCase):
         if input_tensors is None:
             input_tensors = reference_tensors
 
+        def do_input_map(fn, input):
+            return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+        def flatten_inputs(inputs):
+            def input_reduce(input, fn, acc):
+                if isinstance(input, torch.Tensor):
+                    fn(input, acc)
+                elif isinstance(input, dict):
+                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+                else:
+                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+                return acc
+            return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), []))
+
         nograd_inputs = reference_tensors
         if inputs_require_grads:
-            recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
+            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
         else:
             recording_inputs = reference_tensors
 
@@ -558,12 +574,12 @@ class JitTestCase(TestCase):
         # test single grad case
         outputs = func(*recording_inputs)
         if inputs_require_grads:
-            grads = torch.autograd.grad(allSum(outputs), recording_inputs,
+            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
                                         allow_unused=allow_unused)
 
         outputs_ge = ge(*recording_inputs)
         if inputs_require_grads:
-            grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
                                            allow_unused=allow_unused)
         self.assertEqual(outputs, outputs_ge)
         if inputs_require_grads:
@@ -574,25 +590,25 @@ class JitTestCase(TestCase):
         outputs = func(*recording_inputs)
         l1 = allSum(outputs)
         if inputs_require_grads:
-            grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
+            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
                                         allow_unused=allow_unused)
         if inputs_require_grads:
             l2 = (allSum(grads) * l1)
-            grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
+            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
 
         if inputs_require_grads:
-            recording_inputs = [Variable(t, requires_grad=True)
-                                for t in reference_tensors]
+            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
 
         outputs_ge = ge(*recording_inputs)
         l1_ge = allSum(outputs_ge)
         if inputs_require_grads:
             grads_ge = torch.autograd.grad(
-                l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
+                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
 
         if inputs_require_grads:
             l2_ge = (allSum(grads_ge) * l1_ge)
-            grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
+            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
 
         self.assertEqual(outputs, outputs_ge)
         if inputs_require_grads:
@@ -1580,8 +1596,6 @@ graph(%x : Tensor,
 
         self.checkTrace(fn, (torch.randn(2, 2),))
 
-    # TODO: implement
-    @unittest.expectedFailure
     def test_input_flatten(self):
         """Check that inputs to traced functions are flattened"""
 
@@ -1592,6 +1606,66 @@ graph(%x : Tensor,
         inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
         self.checkTrace(fn, inputs)
 
+    def test_input_dict_empty(self):
+        def test(d):
+            pass
+
+        with self.assertRaises(RuntimeError):
+            self.checkTrace(test, {})
+
+    def test_input_dict_flattens(self):
+        class Test(torch.nn.Module):
+            def forward(self, d):
+                return d['x'] + d['y']
+
+        inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+        module = torch.jit.trace(Test(), inputs)
+        FileCheck().check('aten::values').check('prim::ListUnpack').run(str(module.graph))
+
+    def test_input_dict_flattens_recursive(self):
+        class Test(torch.nn.Module):
+            def forward(self, d):
+                # Use both to avoid getting optimized away
+                a = d['x'][0]
+                b, c = d['y']
+                return a + b
+
+        inputs = {'x': (torch.rand(2, 2), torch.rand(2, 2)), 'y': (torch.ones(1, 1), torch.ones(2, 1))}
+        module = torch.jit.trace(Test(), inputs)
+        FileCheck().check('aten::values') \
+                   .check('prim::ListUnpack') \
+                   .check_count('prim::TupleUnpack', 2) \
+                   .run(str(module.graph))
+
+    def test_input_dict_checkTrace_mut(self):
+        def test(d):
+            d['x'].tanh_()
+            return d['x']
+        inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+        self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+    def test_input_dict_unify(self):
+        def test(d):
+            return d['int'], d['float']
+        inputs = {'int': torch.ones((2, 2), dtype=torch.int32),
+                  'float': torch.ones((2, 2), dtype=torch.float32)}
+        self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+    def test_input_tuple_of_dicts(self):
+        def test(t):
+            d = t[0]
+            return d['x']['y']
+        inputs = {'x': {'y': torch.rand(2, 3)}}
+        self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
+
+    def test_input_dict_of_dicts(self):
+        def test(d):
+            return d['x']['y']
+        nested_input = {'y': torch.rand(2, 3)}
+        unified_nested = {'y': torch.rand(3, 2)}
+        inputs = {'x': nested_input, 'force_unify': unified_nested}
+        self.checkTrace(test, (inputs,), allow_unused=True)
+
     # TODO: adapt to a GraphExecutor test
     @unittest.skip("Need to instrument GraphExecutors a bit more")
     def test_flags(self):
index 0fe2c07..da59648 100644 (file)
@@ -255,6 +255,8 @@ def _nested_map(condition, fn, condition_msg=None):
             return None
         elif isinstance(obj, (list, tuple)):
             return type(obj)(_map(x) for x in obj)
+        elif isinstance(obj, dict):
+            return {x : _map(obj[x]) for x in obj}
         else:
             raise ValueError("Auto nesting doesn't know how to process "
                              "an input object of type " + torch.typename(obj) +
@@ -284,6 +286,11 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None,
             for o in obj:
                 for var in _iter(o):
                     yield var
+        elif isinstance(obj, dict):
+            # We only accept primitive key types, so we needn't inspect them
+            for o in obj.values():
+                for var in _iter(o):
+                    yield var
         elif allow_unknown:
             yield obj
         else:
index 94a3684..ff7dac5 100644 (file)
@@ -35,24 +35,85 @@ namespace jit {
 // that is confusing to display to the end user since it always reports
 // locations in libtorch code rather than user code.
 
-inline IValue toIValue(py::handle input) {
+using tracer::TypedStack;
+struct TypedIValue : public std::pair<IValue, TypePtr> {
+  using pair::pair;
+
+  IValue& ivalue() {
+    return this->first;
+  }
+  TypePtr& type() {
+    return this->second;
+  }
+};
+
+inline TypedIValue toDictKeyIValue(py::handle key) {
+  if (py::isinstance<py::str>(key)) {
+    return TypedIValue(ConstantString::create(py::cast<std::string>(key)),
+                       StringType::create());
+  } else if (PyLong_Check(key.ptr())) {
+    return TypedIValue(py::cast<int64_t>(key), IntType::create());
+  } else if (PyFloat_Check(key.ptr())) {
+    return TypedIValue(py::cast<double>(key), FloatType::create());
+  } else {
+    AT_ERROR("Dictionary inputs may only have string, int, or float keys");
+  }
+}
+
+inline TypedIValue toTypedIValue(py::handle input) {
   if (THPVariable_Check(input.ptr())) {
     auto ten = py::cast<at::Tensor>(input);
     if (ten.is_sparse()) {
       AT_ERROR("sparse tensors not supported");
     }
-    return ten;
+    return TypedIValue(ten, CompleteTensorType::create(ten));
   } else if (six::isTuple(input)) {
     py::tuple input_tuple = py::cast<py::tuple>(input);
     Stack s;
+    std::vector<TypePtr> t;
     s.reserve(input_tuple.size());
+    t.reserve(input_tuple.size());
     for (py::handle elem : input_tuple) {
-      s.push_back(toIValue(elem));
+      auto info = toTypedIValue(elem);
+      s.push_back(info.first);
+      t.push_back(info.second);
+    }
+    return TypedIValue(Tuple::create(s), TupleType::create(t));
+  } else if (PyDict_Check(input.ptr())) {
+    // Check to make sure we can generate useful input/output types
+    auto dict = py::cast<py::dict>(input);
+    at::ivalue::UnorderedMap elems;
+
+    size_t len = py::len(dict);
+    if (!len) {
+      AT_ERROR("Dictionary inputs must have entries.");
     }
-    return Tuple::create(s);
+    elems.reserve(len);
+
+    TypePtr keyType = nullptr;
+    TypePtr valueType = nullptr;
+    for (auto entry : dict) {
+      auto keyInfo = toDictKeyIValue(entry.first);
+      auto valInfo = toTypedIValue(entry.second);
+      if (!keyType) {
+        keyType = keyInfo.second;
+        valueType = valInfo.second;
+      } else {
+        auto unifiedKey = unifyTypes(keyType, keyInfo.second);
+        auto unifiedValue = unifyTypes(valueType, valInfo.second);
+        if (!unifiedKey || !unifiedValue) {
+          AT_ERROR("Dictionary inputs to traced functions must have consistent type");
+        }
+        keyType = *unifiedKey;
+        valueType = *unifiedValue;
+      }
+      elems.insert(std::make_pair(keyInfo.first, valInfo.first));
+    }
+    return TypedIValue(at::ivalue::GenericDict::create(std::move(elems)),
+                       DictType::create(keyType, valueType));
   } else {
     throw std::runtime_error(c10::str(
-        "Only tensors and (possibly nested) tuples of tensors are supported ",
+        "Only tensors and (possibly nested) tuples of tensors or dicts are supported ",
         "as inputs or outputs of traced functions",
         ", but instead got value of type ",
         py::str(input.get_type().attr("__name__")),
@@ -62,10 +123,19 @@ inline IValue toIValue(py::handle input) {
   }
 }
 
+inline IValue toIValue(py::handle input) {
+    return toTypedIValue(input).ivalue();
+}
+
 inline Stack toStack(const py::tuple& inputs) {
   return toIValue(inputs).toTuple()->elements();
 }
 
+inline TypedStack toTypedStack(const py::tuple& inputs) {
+  auto info = toTypedIValue(inputs);
+  return TypedStack(info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
+}
+
 inline IValue toIValue(
     py::handle obj,
     const TypePtr& type,
index 45365b9..5886765 100644 (file)
@@ -38,7 +38,7 @@ std::string getPythonInterpreterStackTrace() {
 
 std::shared_ptr<torch::jit::Graph> createGraphByTracing(
     const py::function& func,
-    Stack trace_inputs,
+    TypedStack trace_inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs) {
@@ -145,7 +145,7 @@ void initPythonTracerBindings(PyObject* module) {
 
   m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
   m.def("_tracer_enter", [](py::args trace_inputs) {
-    return tracer::enter(toStack(trace_inputs));
+    return tracer::enter(toTypedStack(trace_inputs));
   });
   m.def("_tracer_exit", [](py::tuple var_outputs) {
     tracer::exit(toStack(var_outputs));
index ff68499..74b6c70 100644 (file)
@@ -21,7 +21,7 @@ Node* preRecordPythonTrace(
 
 std::shared_ptr<Graph> createGraphByTracing(
     const py::function& func,
-    Stack inputs,
+    TypedStack inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs = c10::nullopt);
index f2edf87..a579556 100644 (file)
@@ -1540,15 +1540,33 @@ int dictKeys(Stack& stack) {
   return 0;
 }
 
-int dictValues(Stack& stack) {
-  auto dict = pop(stack).toGenericDictRef();
-  std::vector<IValue> values;
+template <typename Elem>
+std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
+  std::vector<Elem> values;
   values.reserve(dict.size());
   for (auto item : dict) {
-    values.push_back(item.second);
+    values.push_back(item.second.to<Elem>());
   }
-  push(stack, IValue(values));
-  return 0;
+  return values;
+}
+
+Operation dictValues(const Node* n) {
+  auto outputType = n->output()->type()->expect<ListType>();
+  return [=](Stack& stack) -> int {
+    auto dict = pop(stack).toGenericDictRef();
+    if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
+      push(stack, makeListForDictValues<at::Tensor>(dict));
+    } else if (outputType->getElementType() == IntType::get()) {
+      push(stack, makeListForDictValues<int64_t>(dict));
+    } else if (outputType->getElementType() == FloatType::get()) {
+      push(stack, makeListForDictValues<double>(dict));
+    } else if (outputType->getElementType() == BoolType::get()) {
+      push(stack, makeListForDictValues<bool>(dict));
+    } else {
+      push(stack, makeListForDictValues<IValue>(dict));
+    }
+    return 0;
+  };
 }
 
 int dictIndex(Stack& stack) {
index 96d1e5b..18f5426 100644 (file)
@@ -939,13 +939,19 @@ void initJitScriptBindings(PyObject* module) {
             // this was ensured in python before calling this function
             std::vector<Slot> parameters;
             gatherParametersAndBuffers(parameters, *self);
-            Stack inputs = toStack(input_tuple);
-            for (const Slot& param : parameters) {
-              inputs.emplace_back(param.value());
+            auto typed_inputs = toTypedStack(input_tuple);
+            if (parameters.size() > 0) {
+              auto inputs = typed_inputs.stack();
+              auto input_types = typed_inputs.types()->elements().vec();
+              for (const Slot& param : parameters) {
+                inputs.emplace_back(param.value());
+                input_types.push_back(incompleteInferTypeFrom(param.value()));
+              }
+              typed_inputs = TypedStack(inputs, TupleType::create(input_types));
             }
             auto graph = tracer::createGraphByTracing(
                 func,
-                inputs,
+                typed_inputs,
                 var_lookup_fn,
                 force_outplace,
                 input_tuple.size());
index 4680ae2..785dd89 100644 (file)
@@ -201,7 +201,7 @@ Value* getNestedOutputTrace(
 // Start tracing, treating 'inputs' as inputs to the trace, which can be
 // varied on subsequent invocations of the trace.  Any other variables
 // will be treated as constants.
-std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
+std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) {
   if (isTracing()) {
     AT_ERROR("Tracing can't be nested");
   }
@@ -234,17 +234,34 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
         elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
       }
       return Tuple::create(std::move(elems));
+    } else if (auto dict_type = type->cast<DictType>()) {
+      auto elem_pairs = input.toGenericDict()->elements();
+      auto unpack_to_list = state->graph->insert(aten::values, {value});
+      auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size());
+      auto unpack_node = state->graph->insertNode(list_unpack);
+      auto elem_values = unpack_node->outputs();
+
+      AT_ASSERT(elem_pairs.size() == elem_values.size());
+
+      size_t i = 0;
+      for (const auto &pair : elem_pairs) {
+        elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
+      }
+
+      return c10::ivalue::GenericDict::create(std::move(elem_pairs));
     } else {
       AT_ERROR(
-          "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
-          type);
+          "Only tensors or (possibly nested) dict or tuples of tensors can be "
+          "inputs to traced functions. Got ", type);
     }
   };
-  for (IValue& input : inputs) {
+  size_t i = 0;
+  auto input_types = inputs.types()->elements();
+  for (IValue& input : inputs.stack()) {
     input = add_input(
-        input, incompleteInferTypeFrom(input), state->graph->addInput());
+        input, input_types[i++], state->graph->addInput());
   }
-  return std::make_pair(state, inputs);
+  return std::make_pair(state, inputs.stack());
 }
 
 // Exit a trace, treating 'outputs' as the outputs of the trace.  These
index 09ea106..cca4a95 100644 (file)
@@ -65,7 +65,30 @@ TORCH_API Value* getNestedOutputTrace(
     const std::shared_ptr<TracingState>& state,
     const IValue& iv);
 
-TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs);
+struct TypedStack : public std::pair<Stack, TupleTypePtr>
+{
+  using pair::pair;
+
+  // NB: The inherited default constructor gives nullptr for |type|,
+  //     so we provide a saner one.
+  TypedStack()
+    : pair({}, TupleType::create({}))
+  {}
+
+  Stack& stack() {
+    return this->first;
+  }
+  TupleTypePtr& types() {
+    return this->second;
+  }
+  size_t size() {
+    auto s = stack().size();
+    AT_ASSERT(s == types()->elements().size());
+    return s;
+  }
+};
+
+TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs);
 
 TORCH_API void exit(const Stack& outputs);
 
index 5a6ffcf..f8d82ee 100644 (file)
@@ -656,7 +656,7 @@ def trace(func,
         return func
     executor_options = {'optimize': bool(optimize)}
     # Special case for common case of passing a single Tensor
-    if isinstance(example_inputs, torch.Tensor):
+    if isinstance(example_inputs, (torch.Tensor, dict)):
         example_inputs = (example_inputs,)
     # done primarily so that weird iterables fail here and not pybind11 code
     elif not isinstance(example_inputs, tuple):