From 6d744f8fbfe6ef1c90534893d4ca2b2caf8a0df4 Mon Sep 17 00:00:00 2001 From: Mikhail Zolotukhin Date: Fri, 22 Feb 2019 14:56:02 -0800 Subject: [PATCH] Preserve names when converting to/from NetDef. Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17378 Differential Revision: D14176515 Pulled By: ZolotukhinM fbshipit-source-id: da9ea28310250ab3ca3a99cdc210fd8d1fbbc82b --- test/cpp/jit/test_netdef_converter.h | 78 ++++++++++++++++++++++++++++++++++++ torch/csrc/jit/netdef_converter.cpp | 35 +++++++++++++++- 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/test/cpp/jit/test_netdef_converter.h b/test/cpp/jit/test_netdef_converter.h index 3dd63d6..3dba9b8 100644 --- a/test/cpp/jit/test_netdef_converter.h +++ b/test/cpp/jit/test_netdef_converter.h @@ -9,6 +9,21 @@ namespace torch { namespace jit { +static caffe2::OperatorDef createOperator( + const std::string& name, + const std::vector& inputs, + const std::vector& outputs) { + caffe2::OperatorDef op; + op.set_type(name); + for (const auto& input : inputs) { + op.add_input(input); + } + for (const auto& output : outputs) { + op.add_output(output); + } + return op; +} + void testNetDefConverter(std::ostream& out = std::cout) { { // Check a simple net conversion back and forth. @@ -141,6 +156,69 @@ void testNetDefConverter(std::ostream& out = std::cout) { n->ss(Symbol::fromQualString("attr::ss_attr")) == std::vector({"Winter", "Summer"})); } + { + // Check how value names are preserved in conversion. They naturally might + // change as IR is in SSA form, but we should try not to change names of + // external inputs and outputs. + + // Create a simple net: + // net(ext_inputs = {a, b, c}) + // a = foo::bar(a, b) + // u = foo::baz(b, c) + // x = foo::qux(u, a) + // x = foo::quux(a, x) + // -> (ext_outputs = {x}) + // + caffe2::NetDef net; + + *net.add_op() = createOperator("foo::bar", {"a", "b"}, {"a"}); + *net.add_op() = createOperator("foo::baz", {"b", "c"}, {"u"}); + *net.add_op() = createOperator("foo::qux", {"u", "a"}, {"x"}); + *net.add_op() = createOperator("foo::quux", {"a", "x", "u"}, {"x"}); + net.add_external_input("a"); + net.add_external_input("b"); + net.add_external_input("c"); + net.add_external_output("x"); + + // Expect the following graph to be generated: + // graph(%a : Tensor, + // %b : Tensor, + // %c : Tensor) { + // %a.1 : Tensor = foo::bar(%a, %b) + // %u : Tensor = foo::baz(%b, %c) + // %x.1 : Tensor = foo::qux(%u, %a.1) + // %x : Tensor = foo::quux(%a.1, %x.1, u) + // return (%x) + // } + Graph graph; + std::unordered_map vmap; + convertNetDefToIR(net, &graph, &vmap); + AT_ASSERT(graph.inputs().size() == 3); + AT_ASSERT(graph.inputs()[0]->uniqueName() == "a"); + AT_ASSERT(graph.inputs()[1]->uniqueName() == "b"); + AT_ASSERT(graph.inputs()[2]->uniqueName() == "c"); + + AT_ASSERT(graph.outputs().size() == 1); + AT_ASSERT(graph.outputs()[0]->uniqueName() == "x"); + + Node* quux = graph.outputs()[0]->node(); + Value* a0 = quux->inputs()[0]; + Value* x0 = quux->inputs()[1]; + Value* u = quux->inputs()[2]; + AT_ASSERT(a0->uniqueName() != "a" && a0->uniqueNameBase() == "a"); + AT_ASSERT(x0->uniqueName() != "x" && x0->uniqueNameBase() == "x"); + AT_ASSERT(u->uniqueName() == "u"); + + // Convert back to netdef and check if the names are preserved. + // We still expect them to be in SSA form, but we should preserve names for + // external inputs and outputs. + caffe2::NetDef net2; + convertIRToNetDef(&net2, graph); + AT_ASSERT(net2.external_input().Get(0) == "a"); + AT_ASSERT(net2.external_input().Get(1) == "b"); + AT_ASSERT(net2.external_input().Get(2) == "c"); + AT_ASSERT(net2.external_output().Get(0) == "x"); + } } } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/netdef_converter.cpp b/torch/csrc/jit/netdef_converter.cpp index 4c8ca5f..8f77165 100644 --- a/torch/csrc/jit/netdef_converter.cpp +++ b/torch/csrc/jit/netdef_converter.cpp @@ -77,11 +77,13 @@ void convertNetDefToIR( std::unordered_map* valueMapPtr, const std::string& prefix) { std::unordered_map& valueMap = *valueMapPtr; + std::unordered_map namesMap; valueMap.clear(); for (const auto& inputName : net.external_input()) { AT_ASSERT(!valueMap.count(inputName)); valueMap[inputName] = g->addInput(); + namesMap[valueMap.at(inputName)] = inputName; } for (const auto& op : net.op()) { @@ -98,7 +100,9 @@ void convertNetDefToIR( for (const auto& output : op.output()) { // If output already exists in valueMap, overwrite it. This way we will // have the last definition of a value named 'output' in valueMap. - valueMap[output] = node->outputs()[idx++]; + Value* v = node->outputs()[idx++]; + valueMap[output] = v; + namesMap[v] = output; } for (const auto& arg : op.arg()) { convertArg(arg, node); @@ -108,6 +112,35 @@ void convertNetDefToIR( for (const auto& outputName : net.external_output()) { AT_ASSERT(valueMap.count(outputName)); g->registerOutput(valueMap.at(outputName)); + namesMap[valueMap.at(outputName)] = outputName; + } + + // Set proper unique names for all values. + // We will set the names for external inputs and outputs last, so that if the + // names are reused, then intermediate values will be renamed and the external + // values will keep the original names. + for (Node* n : g->nodes()) { + for (Value* v : n->outputs()) { + AT_ASSERT(namesMap.count(v)); + const std::string& name = namesMap.at(v); + if (Value::isValidName(name)) { + v->setUniqueName(name); + } + } + } + for (Value* v : g->inputs()) { + AT_ASSERT(namesMap.count(v)); + const std::string& name = namesMap.at(v); + if (Value::isValidName(name)) { + v->setUniqueName(name); + } + } + for (Value* v : g->outputs()) { + AT_ASSERT(namesMap.count(v)); + const std::string& name = namesMap.at(v); + if (Value::isValidName(name)) { + v->setUniqueName(name); + } } } -- 2.7.4