}
}
return static_cast<TypePtr>(TupleType::create(elements));
+ } else if (t1->cast<DictType>() && t2->cast<DictType>()) {
+ auto dict1 = t1->cast<DictType>();
+ auto dict2 = t2->cast<DictType>();
+
+ auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType());
+ auto unshaped_value1 = unshapedType(dict1->getValueType());
+ auto unshaped_value2 = unshapedType(dict2->getValueType());
+ auto unified_value = tryEitherIsTheSuperType(unshaped_value1, unshaped_value2);
+ if (!unified_key || !unified_value) {
+ return c10::nullopt;
+ }
+ return DictType::create(*unified_key, *unified_value);
}
return c10::nullopt;
std::shared_ptr<Graph> trace(
const ADTestSpec& test,
const variable_list& vars_in) {
+ Stack input_vars = fmap<IValue>(vars_in);
+ std::vector<TypePtr> input_types;
+ input_types.reserve(input_vars.size());
+ for (auto i = 0; i < input_vars.size(); i++) {
+ input_types.push_back(TensorType::get());
+ }
+ auto input_typeptr = TupleType::create(std::move(input_types));
std::shared_ptr<tracer::TracingState> state;
Stack trace_stack_in;
- std::tie(state, trace_stack_in) = tracer::enter(fmap<IValue>(vars_in));
+ std::tie(state, trace_stack_in) =
+ tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
variable_list trace_vars_in = fmap(
trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
auto trace_vars_out = test(trace_vars_in);
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
+from torch.autograd.function import _nested_map
from torch.onnx import OperatorExportTypes
from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
freeze_rng_state, set_rng_seed, slowTest
from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
-from functools import wraps
+from functools import wraps, reduce
import os
import io
import sys
if input_tensors is None:
input_tensors = reference_tensors
+ # Apply `fn` to every torch.Tensor inside a (possibly nested) structure,
+ # preserving the container shape (delegates traversal to _nested_map).
+ def do_input_map(fn, input):
+ return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)
+
+ # Collect every tensor in a nested input structure into a flat tuple, in
+ # traversal order, so they can be handed to torch.autograd.grad.
+ # NOTE(review): the `inputs` parameter is never used — the closure variable
+ # `recording_inputs` is flattened instead. Callers happen to pass
+ # `recording_inputs`, so behavior is unchanged, but this should be tidied.
+ def flatten_inputs(inputs):
+ # Depth-first reduce: tensors are fed to `fn`; dicts recurse over keys
+ # (relying on dict iteration order — presumably matching the order used
+ # when building the traced input types; TODO confirm); other iterables
+ # recurse over values. `acc` is mutated via `fn`, so reduce's return
+ # value is intentionally discarded.
+ def input_reduce(input, fn, acc):
+ if isinstance(input, torch.Tensor):
+ fn(input, acc)
+ elif isinstance(input, dict):
+ reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
+ else:
+ reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
+ return acc
+ return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), []))
+
nograd_inputs = reference_tensors
if inputs_require_grads:
- recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
+ recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
+ flattened_recording_inputs = flatten_inputs(recording_inputs)
else:
recording_inputs = reference_tensors
# test single grad case
outputs = func(*recording_inputs)
if inputs_require_grads:
- grads = torch.autograd.grad(allSum(outputs), recording_inputs,
+ grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
allow_unused=allow_unused)
outputs_ge = ge(*recording_inputs)
if inputs_require_grads:
- grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
+ grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
allow_unused=allow_unused)
self.assertEqual(outputs, outputs_ge)
if inputs_require_grads:
outputs = func(*recording_inputs)
l1 = allSum(outputs)
if inputs_require_grads:
- grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
+ grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
allow_unused=allow_unused)
if inputs_require_grads:
l2 = (allSum(grads) * l1)
- grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
+ grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)
if inputs_require_grads:
- recording_inputs = [Variable(t, requires_grad=True)
- for t in reference_tensors]
+ recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
+ flattened_recording_inputs = flatten_inputs(recording_inputs)
outputs_ge = ge(*recording_inputs)
l1_ge = allSum(outputs_ge)
if inputs_require_grads:
grads_ge = torch.autograd.grad(
- l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
+ l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
if inputs_require_grads:
l2_ge = (allSum(grads_ge) * l1_ge)
- grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
+ grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)
self.assertEqual(outputs, outputs_ge)
if inputs_require_grads:
self.checkTrace(fn, (torch.randn(2, 2),))
- # TODO: implement
- @unittest.expectedFailure
def test_input_flatten(self):
"""Check that inputs to traced functions are flattened"""
inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
self.checkTrace(fn, inputs)
+ # An empty dict gives the tracer no way to infer key/value types, so
+ # tracing must fail with a RuntimeError (raised from the C++ side's
+ # "Dictionary inputs must have entries." check).
+ def test_input_dict_empty(self):
+ def test(d):
+ pass
+
+ with self.assertRaises(RuntimeError):
+ self.checkTrace(test, {})
+
+ # A dict input should be flattened in the traced graph: the dict's values
+ # are extracted (aten::values) and then unpacked into individual inputs
+ # (prim::ListUnpack).
+ def test_input_dict_flattens(self):
+ class Test(torch.nn.Module):
+ def forward(self, d):
+ return d['x'] + d['y']
+
+ inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+ module = torch.jit.trace(Test(), inputs)
+ FileCheck().check('aten::values').check('prim::ListUnpack').run(str(module.graph))
+
+ # Dict values that are tuples must be flattened recursively: first the
+ # dict (aten::values + prim::ListUnpack), then each tuple value
+ # (one prim::TupleUnpack per entry).
+ def test_input_dict_flattens_recursive(self):
+ class Test(torch.nn.Module):
+ def forward(self, d):
+ # Use both to avoid getting optimized away
+ a = d['x'][0]
+ b, c = d['y']
+ return a + b
+
+ inputs = {'x': (torch.rand(2, 2), torch.rand(2, 2)), 'y': (torch.ones(1, 1), torch.ones(2, 1))}
+ module = torch.jit.trace(Test(), inputs)
+ FileCheck().check('aten::values') \
+ .check('prim::ListUnpack') \
+ .check_count('prim::TupleUnpack', 2) \
+ .run(str(module.graph))
+
+ # In-place mutation of a tensor reached through a dict input must be
+ # recorded correctly by the tracer. inputs_require_grads=False because the
+ # in-place tanh_ would not be valid on leaf tensors that require grad.
+ def test_input_dict_checkTrace_mut(self):
+ def test(d):
+ d['x'].tanh_()
+ return d['x']
+ inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)}
+ self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+ # Dict values of different dtypes (int32 vs float32) must still unify to a
+ # common tensor value type rather than being rejected.
+ def test_input_dict_unify(self):
+ def test(d):
+ return d['int'], d['float']
+ inputs = {'int': torch.ones((2, 2), dtype=torch.int32),
+ 'float': torch.ones((2, 2), dtype=torch.float32)}
+ self.checkTrace(test, (inputs,), inputs_require_grads=False)
+
+ # Dicts nested inside a tuple input (including a dict-of-dicts value) must
+ # be traceable. allow_unused=True since only part of the input is used.
+ def test_input_tuple_of_dicts(self):
+ def test(t):
+ d = t[0]
+ return d['x']['y']
+ inputs = {'x': {'y': torch.rand(2, 3)}}
+ self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
+
+ # A dict whose values are themselves dicts with different tensor shapes:
+ # the 'force_unify' entry has a transposed shape (3,2) vs (2,3) so the
+ # tracer is forced to unify the two value types instead of using one
+ # complete tensor type.
+ def test_input_dict_of_dicts(self):
+ def test(d):
+ return d['x']['y']
+ nested_input = {'y': torch.rand(2, 3)}
+ unified_nested = {'y': torch.rand(3, 2)}
+ inputs = {'x': nested_input, 'force_unify': unified_nested}
+ self.checkTrace(test, (inputs,), allow_unused=True)
+
# TODO: adapt to a GraphExecutor test
@unittest.skip("Need to instrument GraphExecutors a bit more")
def test_flags(self):
return None
elif isinstance(obj, (list, tuple)):
return type(obj)(_map(x) for x in obj)
+ elif isinstance(obj, dict):
+ return {x : _map(obj[x]) for x in obj}
else:
raise ValueError("Auto nesting doesn't know how to process "
"an input object of type " + torch.typename(obj) +
for o in obj:
for var in _iter(o):
yield var
+ elif isinstance(obj, dict):
+ # We only accept primitive key types, so we needn't inspect them
+ for o in obj.values():
+ for var in _iter(o):
+ yield var
elif allow_unknown:
yield obj
else:
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.
-inline IValue toIValue(py::handle input) {
+using tracer::TypedStack;
+// An IValue paired with the TypePtr describing it. Used while converting
+// Python trace inputs so both the value and its (possibly complete) type
+// can be threaded through together.
+struct TypedIValue : public std::pair<IValue, TypePtr> {
+  using pair::pair;
+
+  IValue& ivalue() {
+    return this->first;
+  }
+  TypePtr& type() {
+    return this->second;
+  }
+};
+
+// Convert a Python dict key to a TypedIValue. Only str, int, and float keys
+// are accepted; anything else is a hard error.
+// NOTE(review): PyLong_Check also accepts bool (bool subclasses int in
+// Python), so a bool key would silently become an IntType key — confirm
+// whether that is intended.
+inline TypedIValue toDictKeyIValue(py::handle key) {
+  if (py::isinstance<py::str>(key)) {
+    return TypedIValue(ConstantString::create(py::cast<std::string>(key)),
+                       StringType::create());
+  } else if (PyLong_Check(key.ptr())) {
+    return TypedIValue(py::cast<int64_t>(key), IntType::create());
+  } else if (PyFloat_Check(key.ptr())) {
+    return TypedIValue(py::cast<double>(key), FloatType::create());
+  } else {
+    AT_ERROR("Dictionary inputs may only have string, int, or float keys");
+  }
+}
+
+// Recursively convert a Python trace input (tensor, tuple, or dict) into a
+// TypedIValue carrying both the IValue and its inferred type.
+inline TypedIValue toTypedIValue(py::handle input) {
  if (THPVariable_Check(input.ptr())) {
    auto ten = py::cast<at::Tensor>(input);
    if (ten.is_sparse()) {
      AT_ERROR("sparse tensors not supported");
    }
-    return ten;
+    // Tensors get a CompleteTensorType (dtype + sizes + strides).
+    return TypedIValue(ten, CompleteTensorType::create(ten));
  } else if (six::isTuple(input)) {
    py::tuple input_tuple = py::cast<py::tuple>(input);
    Stack s;
+    std::vector<TypePtr> t;
    s.reserve(input_tuple.size());
+    t.reserve(input_tuple.size());
    for (py::handle elem : input_tuple) {
-      s.push_back(toIValue(elem));
+      auto info = toTypedIValue(elem);
+      s.push_back(info.first);
+      t.push_back(info.second);
+    }
+    return TypedIValue(Tuple::create(s), TupleType::create(t));
+  } else if (PyDict_Check(input.ptr())) {
+    // Check to make sure we can generate useful input/output types
+    auto dict = py::cast<py::dict>(input);
+    at::ivalue::UnorderedMap elems;
+
+    size_t len = py::len(dict);
+    if (!len) {
+      // An empty dict gives no entries to infer key/value types from.
+      AT_ERROR("Dictionary inputs must have entries.");
    }
-    return Tuple::create(s);
+    elems.reserve(len);
+
+    // Key/value types are unified across all entries; inconsistent entries
+    // (no common supertype) are a tracing error.
+    TypePtr keyType = nullptr;
+    TypePtr valueType = nullptr;
+    for (auto entry : dict) {
+      auto keyInfo = toDictKeyIValue(entry.first);
+      auto valInfo = toTypedIValue(entry.second);
+      if (!keyType) {
+        keyType = keyInfo.second;
+        valueType = valInfo.second;
+      } else {
+        auto unifiedKey = unifyTypes(keyType, keyInfo.second);
+        auto unifiedValue = unifyTypes(valueType, valInfo.second);
+        if (!unifiedKey || !unifiedValue) {
+          AT_ERROR("Dictionary inputs to traced functions must have consistent type");
+        }
+        keyType = *unifiedKey;
+        valueType = *unifiedValue;
+      }
+      elems.insert(std::make_pair(keyInfo.first, valInfo.first));
+    }
+    return TypedIValue(at::ivalue::GenericDict::create(std::move(elems)),
+                       DictType::create(keyType, valueType));
  } else {
    throw std::runtime_error(c10::str(
-        "Only tensors and (possibly nested) tuples of tensors are supported ",
+        "Only tensors and (possibly nested) tuples of tensors or dicts are supported ",
        "as inputs or outputs of traced functions",
        ", but instead got value of type ",
        py::str(input.get_type().attr("__name__")),
}
}
+// Convenience wrapper: convert a Python input and discard the inferred type.
+inline IValue toIValue(py::handle input) {
+  return toTypedIValue(input).ivalue();
+}
+
inline Stack toStack(const py::tuple& inputs) {
return toIValue(inputs).toTuple()->elements();
}
+// Convert a Python tuple of trace inputs into a TypedStack: the flattened
+// IValue stack plus the TupleType describing each element.
+inline TypedStack toTypedStack(const py::tuple& inputs) {
+  auto info = toTypedIValue(inputs);
+  return TypedStack(info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
+}
+
inline IValue toIValue(
py::handle obj,
const TypePtr& type,
std::shared_ptr<torch::jit::Graph> createGraphByTracing(
const py::function& func,
- Stack trace_inputs,
+ TypedStack trace_inputs,
const py::function& var_name_lookup_fn,
bool force_outplace,
const c10::optional<size_t>& num_real_inputs) {
m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
m.def("_tracer_enter", [](py::args trace_inputs) {
- return tracer::enter(toStack(trace_inputs));
+ return tracer::enter(toTypedStack(trace_inputs));
});
m.def("_tracer_exit", [](py::tuple var_outputs) {
tracer::exit(toStack(var_outputs));
std::shared_ptr<Graph> createGraphByTracing(
const py::function& func,
- Stack inputs,
+ TypedStack inputs,
const py::function& var_name_lookup_fn,
bool force_outplace,
const c10::optional<size_t>& num_real_inputs = c10::nullopt);
return 0;
}
-int dictValues(Stack& stack) {
- auto dict = pop(stack).toGenericDictRef();
- std::vector<IValue> values;
+// Copy a generic dict's values into a std::vector<Elem>, converting each
+// IValue to the concrete element type.
+// NOTE(review): `for (auto item : dict)` copies each key/value pair per
+// iteration; `const auto&` would avoid that — confirm and clean up.
+template <typename Elem>
+std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
+  std::vector<Elem> values;
  values.reserve(dict.size());
  for (auto item : dict) {
-    values.push_back(item.second);
+    values.push_back(item.second.to<Elem>());
  }
-  push(stack, IValue(values));
-  return 0;
+  return values;
+}
+
+// Build the aten::values operation for a dict. The element type of the
+// node's output ListType (known at graph-construction time) selects which
+// concretely-typed list to produce; anything unrecognized falls back to a
+// generic IValue list.
+Operation dictValues(const Node* n) {
+  auto outputType = n->output()->type()->expect<ListType>();
+  return [=](Stack& stack) -> int {
+    auto dict = pop(stack).toGenericDictRef();
+    if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
+      push(stack, makeListForDictValues<at::Tensor>(dict));
+    } else if (outputType->getElementType() == IntType::get()) {
+      push(stack, makeListForDictValues<int64_t>(dict));
+    } else if (outputType->getElementType() == FloatType::get()) {
+      push(stack, makeListForDictValues<double>(dict));
+    } else if (outputType->getElementType() == BoolType::get()) {
+      push(stack, makeListForDictValues<bool>(dict));
+    } else {
+      push(stack, makeListForDictValues<IValue>(dict));
+    }
+    return 0;
+  };
}
int dictIndex(Stack& stack) {
// this was ensured in python before calling this function
std::vector<Slot> parameters;
gatherParametersAndBuffers(parameters, *self);
- Stack inputs = toStack(input_tuple);
- for (const Slot& param : parameters) {
- inputs.emplace_back(param.value());
+ auto typed_inputs = toTypedStack(input_tuple);
+ if (parameters.size() > 0) {
+ auto inputs = typed_inputs.stack();
+ auto input_types = typed_inputs.types()->elements().vec();
+ for (const Slot& param : parameters) {
+ inputs.emplace_back(param.value());
+ input_types.push_back(incompleteInferTypeFrom(param.value()));
+ }
+ typed_inputs = TypedStack(inputs, TupleType::create(input_types));
}
auto graph = tracer::createGraphByTracing(
func,
- inputs,
+ typed_inputs,
var_lookup_fn,
force_outplace,
input_tuple.size());
// Start tracing, treating 'inputs' as inputs to the trace, which can be
// varied on subsequent invocations of the trace. Any other variables
// will be treated as constants.
-std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs) {
+std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) {
if (isTracing()) {
AT_ERROR("Tracing can't be nested");
}
elems[i] = add_input(elems[i], elem_types[i], elem_values[i]);
}
return Tuple::create(std::move(elems));
+ } else if (auto dict_type = type->cast<DictType>()) {
+ auto elem_pairs = input.toGenericDict()->elements();
+ auto unpack_to_list = state->graph->insert(aten::values, {value});
+ auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size());
+ auto unpack_node = state->graph->insertNode(list_unpack);
+ auto elem_values = unpack_node->outputs();
+
+ AT_ASSERT(elem_pairs.size() == elem_values.size());
+
+ size_t i = 0;
+ for (const auto &pair : elem_pairs) {
+ elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
+ }
+
+ return c10::ivalue::GenericDict::create(std::move(elem_pairs));
} else {
AT_ERROR(
- "Only tensors or tuples of tensors can be inputs to traced functions. Got ",
- type);
+ "Only tensors or (possibly nested) dict or tuples of tensors can be "
+ "inputs to traced functions. Got ", type);
}
};
- for (IValue& input : inputs) {
+ size_t i = 0;
+ auto input_types = inputs.types()->elements();
+ for (IValue& input : inputs.stack()) {
input = add_input(
- input, incompleteInferTypeFrom(input), state->graph->addInput());
+ input, input_types[i++], state->graph->addInput());
}
- return std::make_pair(state, inputs);
+ return std::make_pair(state, inputs.stack());
}
// Exit a trace, treating 'outputs' as the outputs of the trace. These
const std::shared_ptr<TracingState>& state,
const IValue& iv);
-TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(Stack inputs);
+// A Stack of trace inputs paired with a TupleType describing each element,
+// so the tracer can build graph inputs with accurate types instead of
+// re-inferring them from the IValues.
+struct TypedStack : public std::pair<Stack, TupleTypePtr>
+{
+  using pair::pair;
+
+  // NB: The inherited default constructor gives nullptr for |type|,
+  // so we provide a saner one.
+  TypedStack()
+    : pair({}, TupleType::create({}))
+  {}
+
+  Stack& stack() {
+    return this->first;
+  }
+  TupleTypePtr& types() {
+    return this->second;
+  }
+  // Number of inputs; asserts the stack and its type list stayed in sync.
+  size_t size() {
+    auto s = stack().size();
+    AT_ASSERT(s == types()->elements().size());
+    return s;
+  }
+};
+
+TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs);
TORCH_API void exit(const Stack& outputs);
return func
executor_options = {'optimize': bool(optimize)}
# Special case for common case of passing a single Tensor
- if isinstance(example_inputs, torch.Tensor):
+ if isinstance(example_inputs, (torch.Tensor, dict)):
example_inputs = (example_inputs,)
# done primarily so that weird iterables fail here and not pybind11 code
elif not isinstance(example_inputs, tuple):