From 593bb145ceea548478f83fd820de44b95d375108 Mon Sep 17 00:00:00 2001
From: Eric Faust
Date: Thu, 18 Apr 2019 23:48:59 -0700
Subject: [PATCH] Allow passing dicts as trace inputs. (#18092)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18092

Previously, tracing required all inputs to be either tensors, or tuples of tensors. Now, we allow users to pass dicts as well.
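A minimal usage sketch, mirroring the test_input_dict_flattens test added below
(shapes are illustrative; dict keys must be strings, ints, or floats, and all
values must unify to a consistent type):

    import torch

    class Test(torch.nn.Module):
        def forward(self, d):
            return d['x'] + d['y']

    inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
    # The dict input is flattened into the traced graph via
    # aten::values followed by prim::ListUnpack.
    module = torch.jit.trace(Test(), inputs)
    out = module({'x': torch.rand(3, 4), 'y': torch.rand(3, 4)})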

Differential Revision: D14491795

fbshipit-source-id: 7a2df218e5d00f898d01fa5b9669f9d674280be3
---
 aten/src/ATen/core/type.cpp          | 12 +++++
 test/cpp/jit/test_autodiff.h         | 10 +++-
 test/test_jit.py                     | 98 +++++++++++++++++++++++++++++++-----
 torch/autograd/function.py           |  7 +++
 torch/csrc/jit/pybind_utils.h        | 80 +++++++++++++++++++++++++++--
 torch/csrc/jit/python_tracer.cpp     |  4 +-
 torch/csrc/jit/python_tracer.h       |  2 +-
 torch/csrc/jit/register_prim_ops.cpp | 30 ++++++++---
 torch/csrc/jit/script/init.cpp       | 14 ++++--
 torch/csrc/jit/tracer.cpp            | 29 ++++++++---
 torch/csrc/jit/tracer.h              | 25 ++++++++-
 torch/jit/__init__.py                |  2 +-
 12 files changed, 274 insertions(+), 39 deletions(-)

diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 9f7ffc3..6e24026 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -287,6 +287,18 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
       }
     }
     return static_cast<TypePtr>(TupleType::create(elements));
+  } else if (t1->cast<DictType>() && t2->cast<DictType>()) {
+    auto dict1 = t1->cast<DictType>();
+    auto dict2 = t2->cast<DictType>();
+
+    auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType());
+    auto unshaped_value1 = unshapedType(dict1->getValueType());
+    auto unshaped_value2 = unshapedType(dict2->getValueType());
+    auto unified_value = tryEitherIsTheSuperType(unshaped_value1, unshaped_value2);
+    if (!unified_key || !unified_value) {
+      return c10::nullopt;
+    }
+    return DictType::create(*unified_key, *unified_value);
   }

   return c10::nullopt;
diff --git a/test/cpp/jit/test_autodiff.h b/test/cpp/jit/test_autodiff.h
index 00c7624..094da84 100644
--- a/test/cpp/jit/test_autodiff.h
+++ b/test/cpp/jit/test_autodiff.h
@@ -60,9 +60,17 @@ variable_list get_grad_outputs(const variable_list& vars) {
 std::shared_ptr<Graph> trace(
     const ADTestSpec& test,
     const variable_list& vars_in) {
+  Stack input_vars = fmap<IValue>(vars_in);
+  std::vector<TypePtr> input_types;
+  input_types.reserve(input_vars.size());
+  for (auto i = 0; i < input_vars.size(); i++) {
+    input_types.push_back(TensorType::get());
+  }
+  auto input_typeptr = TupleType::create(std::move(input_types));
   std::shared_ptr<tracer::TracingState> state;
   Stack trace_stack_in;
-  std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in));
+  std::tie(state, trace_stack_in) =
+      tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
   variable_list trace_vars_in = fmap(
       trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
   auto trace_vars_out = test(trace_vars_in);
diff --git a/test/test_jit.py b/test/test_jit.py
index b66a156..f20cfda 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -12,6 +12,7 @@ from contextlib import contextmanager
 from itertools import product, chain
 import torch.jit.frontend
 from torch.autograd import Variable, Function
+from torch.autograd.function import _nested_map
 from torch.onnx import OperatorExportTypes
 from torch._six import inf, PY2, builtins, StringIO
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
@@ -19,7 +20,7 @@ from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     freeze_rng_state, set_rng_seed, slowTest
 from common_nn import module_tests, new_module_tests, criterion_tests
 from textwrap import dedent
-from functools import wraps
+from functools import wraps, reduce
 import os
 import io
 import sys
@@ -535,9 +536,24 @@ class JitTestCase(TestCase):
         if input_tensors is None:
             input_tensors = reference_tensors

+        def do_input_map(fn, input):
+            return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+        def flatten_inputs(inputs):
+            def input_reduce(input, fn, acc):
+                if isinstance(input, torch.Tensor):
+                    fn(input, acc)
+                elif isinstance(input, dict):
+                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+                else:
+                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+                return acc
+            return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))
+
         nograd_inputs = reference_tensors
         if inputs_require_grads:
-            recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
+            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)
         else:
             recording_inputs = reference_tensors

@@ -558,12 +574,12 @@ class JitTestCase(TestCase):
         # test single grad case
         outputs = func(*recording_inputs)
         if inputs_require_grads:
-            grads = torch.autograd.grad(allSum(outputs), recording_inputs,
+            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
                                         allow_unused=allow_unused)

         outputs_ge = ge(*recording_inputs)
         if inputs_require_grads:
-            grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
+            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
                                            allow_unused=allow_unused)
         self.assertEqual(outputs, outputs_ge)
         if inputs_require_grads:
@@ -574,25 +590,25 @@ class JitTestCase(TestCase):
         outputs = func(*recording_inputs)
         l1 = allSum(outputs)
         if inputs_require_grads:
-            grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
+            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
                                         allow_unused=allow_unused)
         if inputs_require_grads:
             l2 = (allSum(grads) * l1)
-            grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
+            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)

         if inputs_require_grads:
-            recording_inputs = [Variable(t, requires_grad=True)
-                                for t in reference_tensors]
+            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+            flattened_recording_inputs = flatten_inputs(recording_inputs)

         outputs_ge = ge(*recording_inputs)
         l1_ge = allSum(outputs_ge)
         if inputs_require_grads:
             grads_ge = torch.autograd.grad(
-                l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
+                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)

         if inputs_require_grads:
             l2_ge = (allSum(grads_ge) * l1_ge)
-            grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
+            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)

         self.assertEqual(outputs, outputs_ge)
         if inputs_require_grads:
@@ -1580,8 +1596,6 @@ graph(%x : Tensor,

         self.checkTrace(fn, (torch.randn(2, 2),))

-    # TODO: implement
-    @unittest.expectedFailure
     def test_input_flatten(self):
         """Check that inputs to traced functions are flattened"""

@@ -1592,6 +1606,66 @@ graph(%x : Tensor,
         inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
         self.checkTrace(fn, inputs)

+    def test_input_dict_empty(self):
+        def test(d):
+            pass
+
+        with self.assertRaises(RuntimeError):
+            self.checkTrace(test, {})
+
+    def test_input_dict_flattens(self):
+        class Test(torch.nn.Module):
+            def forward(self, d):
+                return d['x'] + d['y']
+
+        inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+        module = torch.jit.trace(Test(), inputs)
+        FileCheck().check('aten::values').check('prim::ListUnpack').run(str(module.graph))
+
+    def test_input_dict_flattens_recursive(self):
+        class Test(torch.nn.Module):
+            def forward(self, d):
+                # Use both to avoid getting optimized away
+                a = d['x'][0]
+                b, c = d['y']
+                return a + b
+
+        inputs = {'x': (torch.rand(2, 2), torch.rand(2, 2)), 'y': (torch.ones(1, 1), torch.ones(2, 1))}
+        module = torch.jit.trace(Test(), inputs)
+        FileCheck().check('aten::values') \
+            .check('prim::ListUnpack') \
+            .check_count('prim::TupleUnpack', 2) \
+            .run(str(module.graph))
+
+    def test_input_dict_checkTrace_mut(self):
+        def test(d):
+            d['x'].tanh_()
+            return d['x']
+        inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+        self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+    def test_input_dict_unify(self):
+        def test(d):
+            return d['int'], d['float']
+        inputs = {'int': torch.ones((2, 2), dtype=torch.int32),
+                  'float': torch.ones((2, 2), dtype=torch.float32)}
+        self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+    def test_input_tuple_of_dicts(self):
+        def test(t):
+            d = t[0]
+            return d['x']['y']
+        inputs = {'x': {'y': torch.rand(2, 3)}}
+        self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
+
+    def test_input_dict_of_dicts(self):
+        def test(d):
+            return d['x']['y']
+        nested_input = {'y': torch.rand(2, 3)}
+        unified_nested = {'y': torch.rand(3, 2)}
+        inputs = {'x': nested_input, 'force_unify': unified_nested}
+        self.checkTrace(test, (inputs,), allow_unused=True)
+
     # TODO: adapt to a GraphExecutor test
     @unittest.skip("Need to instrument GraphExecutors a bit more")
     def test_flags(self):
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 0fe2c07..da59648 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -255,6 +255,8 @@ def _nested_map(condition, fn, condition_msg=None):
             return None
         elif isinstance(obj, (list, tuple)):
             return type(obj)(_map(x) for x in obj)
+        elif isinstance(obj, dict):
+            return {x: _map(obj[x]) for x in obj}
         else:
             raise ValueError("Auto nesting doesn't know how to process "
                              "an input object of type " + torch.typename(obj) +
@@ -284,6 +286,11 @@ def _iter_filter(condition, allow_unknown=False, condition_msg=None,
             for o in obj:
                 for var in _iter(o):
                     yield var
+        elif isinstance(obj, dict):
+            # We only accept primitive key types, so we needn't inspect them
+            for o in obj.values():
+                for var in _iter(o):
+                    yield var
         elif allow_unknown:
             yield obj
         else:
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 94a3684..ff7dac5 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -35,24 +35,85 @@ namespace jit {
 // that is confusing to display to the end user since it always reports
 // locations in libtorch code rather than user code.

-inline IValue toIValue(py::handle input) {
+using tracer::TypedStack;
+struct TypedIValue : public std::pair<IValue, TypePtr> {
+  using pair::pair;
+
+  IValue& ivalue() {
+    return this->first;
+  }
+  TypePtr& type() {
+    return this->second;
+  }
+};
+
+inline TypedIValue toDictKeyIValue(py::handle key) {
+  if (py::isinstance<py::str>(key)) {
+    return TypedIValue(ConstantString::create(py::cast<std::string>(key)),
+                       StringType::create());
+  } else if (PyLong_Check(key.ptr())) {
+    return TypedIValue(py::cast<int64_t>(key), IntType::create());
+  } else if (PyFloat_Check(key.ptr())) {
+    return TypedIValue(py::cast<double>(key), FloatType::create());
+  } else {
+    AT_ERROR("Dictionary inputs may only have string, int, or float keys");
+  }
+}
+
+inline TypedIValue toTypedIValue(py::handle input) {
   if (THPVariable_Check(input.ptr())) {
     auto ten = py::cast<at::Tensor>(input);
     if (ten.is_sparse()) {
       AT_ERROR("sparse tensors not supported");
     }
-    return ten;
+    return TypedIValue(ten, CompleteTensorType::create(ten));
   } else if (six::isTuple(input)) {
     py::tuple input_tuple = py::cast<py::tuple>(input);
     Stack s;
+    std::vector<TypePtr> t;
     s.reserve(input_tuple.size());
+    t.reserve(input_tuple.size());
     for (py::handle elem : input_tuple) {
-      s.push_back(toIValue(elem));
+      auto info = toTypedIValue(elem);
+      s.push_back(info.first);
+      t.push_back(info.second);
+    }
+    return TypedIValue(Tuple::create(s), TupleType::create(t));
+  } else if (PyDict_Check(input.ptr())) {
+    // Check to make sure we can generate useful input/output types
+    auto dict = py::cast<py::dict>(input);
+    at::ivalue::UnorderedMap elems;
+
+    size_t len = py::len(dict);
+    if (!len) {
+      AT_ERROR("Dictionary inputs must have entries.");
     }
-    return Tuple::create(s);
+    elems.reserve(len);
+
+    TypePtr keyType = nullptr;
+    TypePtr valueType = nullptr;
+    for (auto entry : dict) {
+      auto keyInfo = toDictKeyIValue(entry.first);
+      auto valInfo = toTypedIValue(entry.second);
+      if (!keyType) {
+        keyType = keyInfo.second;
+        valueType = valInfo.second;
+      } else {
+        auto unifiedKey = unifyTypes(keyType, keyInfo.second);
+        auto unifiedValue = unifyTypes(valueType, valInfo.second);
+        if (!unifiedKey || !unifiedValue) {
+          AT_ERROR("Dictionary inputs to traced functions must have consistent type");
+        }
+        keyType = *unifiedKey;
+        valueType = *unifiedValue;
+      }
+      elems.insert(std::make_pair(keyInfo.first, valInfo.first));
+    }
+    return TypedIValue(at::ivalue::GenericDict::create(std::move(elems)),
+                       DictType::create(keyType, valueType));
   } else {
     throw std::runtime_error(c10::str(
-        "Only tensors and (possibly nested) tuples of tensors are supported ",
+        "Only tensors and (possibly nested) tuples of tensors or dicts are supported ",
         "as inputs or outputs of traced functions",
         ", but instead got value of type ",
         py::str(input.get_type().attr("__name__")),
@@ -62,10 +123,19 @@ inline IValue toIValue(py::handle input) {
   }
 }

+inline IValue toIValue(py::handle input) {
+  return toTypedIValue(input).ivalue();
+}
+
 inline Stack toStack(const py::tuple& inputs) {
   return toIValue(inputs).toTuple()->elements();
 }

+inline TypedStack toTypedStack(const py::tuple& inputs) {
+  auto info = toTypedIValue(inputs);
+  return TypedStack(info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
+}
+
 inline IValue toIValue(
     py::handle obj,
     const TypePtr& type,
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index 45365b9..5886765 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -38,7 +38,7 @@ std::string getPythonInterpreterStackTrace() {
 std::shared_ptr<Graph> createGraphByTracing(
     const py::function& func,
-    Stack trace_inputs,
+    TypedStack trace_inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs) {
@@ -145,7 +145,7 @@ void initPythonTracerBindings(PyObject* module) {
   m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
   m.def("_tracer_enter", [](py::args trace_inputs) {
-    return tracer::enter(toStack(trace_inputs));
+    return tracer::enter(toTypedStack(trace_inputs));
   });
   m.def("_tracer_exit", [](py::tuple var_outputs) {
     tracer::exit(toStack(var_outputs));
diff --git a/torch/csrc/jit/python_tracer.h b/torch/csrc/jit/python_tracer.h
index ff68499..74b6c70 100644
--- a/torch/csrc/jit/python_tracer.h
+++ b/torch/csrc/jit/python_tracer.h
@@ -21,7 +21,7 @@ Node* preRecordPythonTrace(
 std::shared_ptr<Graph> createGraphByTracing(
     const py::function& func,
-    Stack inputs,
+    TypedStack inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     const c10::optional<size_t>& num_real_inputs = c10::nullopt);
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index f2edf87..a579556 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -1540,15 +1540,33 @@ int dictKeys(Stack& stack) {
   return 0;
 }

-int dictValues(Stack& stack) {
-  auto dict = pop(stack).toGenericDictRef();
-  std::vector<IValue> values;
+template <typename Elem>
+std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
+  std::vector<Elem> values;
   values.reserve(dict.size());
   for (auto item : dict) {
-    values.push_back(item.second);
+    values.push_back(item.second.to<Elem>());
   }
-  push(stack, IValue(values));
-  return 0;
+  return values;
+}
+
+Operation dictValues(const Node* n) {
+  auto outputType = n->output()->type()->expect<ListType>();
+  return [=](Stack& stack) -> int {
+    auto dict = pop(stack).toGenericDictRef();
+    if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
+      push(stack, makeListForDictValues<at::Tensor>(dict));
+    } else if (outputType->getElementType() == IntType::get()) {
+      push(stack, makeListForDictValues<int64_t>(dict));
+    } else if (outputType->getElementType() == FloatType::get()) {
+      push(stack, makeListForDictValues<double>(dict));
+    } else if (outputType->getElementType() == BoolType::get()) {
+      push(stack, makeListForDictValues<bool>(dict));
+    } else {
+      push(stack, makeListForDictValues<IValue>(dict));
+    }
+    return 0;
+  };
 }

 int dictIndex(Stack& stack) {
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 96d1e5b..18f5426 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -939,13 +939,19 @@ void initJitScriptBindings(PyObject* module) {
             // this was ensured in python before calling this function
             std::vector<Slot> parameters;
             gatherParametersAndBuffers(parameters, *self);
-            Stack inputs = toStack(input_tuple);
-            for (const Slot& param : parameters) {
-              inputs.emplace_back(param.value());
+            auto typed_inputs = toTypedStack(input_tuple);
+            if (parameters.size() > 0) {
+              auto inputs = typed_inputs.stack();
+              auto input_types = typed_inputs.types()->elements().vec();
+              for (const Slot& param : parameters) {
+                inputs.emplace_back(param.value());
+                input_types.push_back(incompleteInferTypeFrom(param.value()));
+              }
+              typed_inputs = TypedStack(inputs, TupleType::create(input_types));
             }
             auto graph = tracer::createGraphByTracing(
                 func,
-                inputs,
+                typed_inputs,
                 var_lookup_fn,
                 force_outplace,
                 input_tuple.size());
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 4680ae2..785dd89 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -201,7 +201,7 @@ Value* getNestedOutputTrace(
 // Start tracing, treating 'inputs' as inputs to the trace, which can be
 // varied on subsequent invocations of the trace. Any other variables
 // will be treated as constants.
-std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
+std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) {
   if (isTracing()) {
     AT_ERROR("Tracing can't be nested");
   }
@@ -234,17 +234,34 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
        elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
      }
      return Tuple::create(std::move(elems));
+    } else if (auto dict_type = type->cast<DictType>()) {
+      auto elem_pairs = input.toGenericDict()->elements();
+      auto unpack_to_list = state->graph->insert(aten::values, {value});
+      auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size());
+      auto unpack_node = state->graph->insertNode(list_unpack);
+      auto elem_values = unpack_node->outputs();
+
+      AT_ASSERT(elem_pairs.size() == elem_values.size());
+
+      size_t i = 0;
+      for (const auto &pair : elem_pairs) {
+        elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
+      }
+
+      return c10::ivalue::GenericDict::create(std::move(elem_pairs));
     } else {
       AT_ERROR(
-          "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
-          type);
+          "Only tensors or (possibly nested) dict or tuples of tensors can be "
+          "inputs to traced functions. Got ", type);
     }
   };
-  for (IValue& input : inputs) {
+  size_t i = 0;
+  auto input_types = inputs.types()->elements();
+  for (IValue& input : inputs.stack()) {
     input = add_input(
-        input, incompleteInferTypeFrom(input), state->graph->addInput());
+        input, input_types[i++], state->graph->addInput());
   }
-  return std::make_pair(state, inputs);
+  return std::make_pair(state, inputs.stack());
 }

 // Exit a trace, treating 'outputs' as the outputs of the trace. These
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index 09ea106..cca4a95 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -65,7 +65,30 @@ TORCH_API Value* getNestedOutputTrace(
     const std::shared_ptr<TracingState>& state,
     const IValue& iv);

-TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs);
+struct TypedStack : public std::pair<Stack, TupleTypePtr>
+{
+  using pair::pair;
+
+  // NB: The inherited default constructor gives nullptr for |type|,
+  // so we provide a saner one.
+  TypedStack()
+    : pair({}, TupleType::create({}))
+  {}
+
+  Stack& stack() {
+    return this->first;
+  }
+  TupleTypePtr& types() {
+    return this->second;
+  }
+  size_t size() {
+    auto s = stack().size();
+    AT_ASSERT(s == types()->elements().size());
+    return s;
+  }
+};
+
+TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs);

 TORCH_API void exit(const Stack& outputs);
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 5a6ffcf..f8d82ee 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -656,7 +656,7 @@ def trace(func,
         return func
     executor_options = {'optimize': bool(optimize)}
     # Special case for common case of passing a single Tensor
-    if isinstance(example_inputs, torch.Tensor):
+    if isinstance(example_inputs, (torch.Tensor, dict)):
         example_inputs = (example_inputs,)
     # done primarily so that weird iterables fail here and not pybind11 code
     elif not isinstance(example_inputs, tuple):
-- 
2.7.4