Preserve names when converting to/from NetDef.
authorMikhail Zolotukhin <mvz@fb.com>
Fri, 22 Feb 2019 22:56:02 +0000 (14:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 22 Feb 2019 23:25:52 +0000 (15:25 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17378

Differential Revision: D14176515

Pulled By: ZolotukhinM

fbshipit-source-id: da9ea28310250ab3ca3a99cdc210fd8d1fbbc82b

test/cpp/jit/test_netdef_converter.h
torch/csrc/jit/netdef_converter.cpp

index 3dd63d6..3dba9b8 100644 (file)
@@ -9,6 +9,21 @@
 namespace torch {
 namespace jit {
 
+static caffe2::OperatorDef createOperator(
+    const std::string& name,
+    const std::vector<std::string>& inputs,
+    const std::vector<std::string>& outputs) {
+  caffe2::OperatorDef op;
+  op.set_type(name);
+  for (const auto& input : inputs) {
+    op.add_input(input);
+  }
+  for (const auto& output : outputs) {
+    op.add_output(output);
+  }
+  return op;
+}
+
 void testNetDefConverter(std::ostream& out = std::cout) {
   {
     // Check a simple net conversion back and forth.
@@ -141,6 +156,69 @@ void testNetDefConverter(std::ostream& out = std::cout) {
         n->ss(Symbol::fromQualString("attr::ss_attr")) ==
         std::vector<std::string>({"Winter", "Summer"}));
   }
+  {
+    // Check how value names are preserved in conversion. They naturally might
+    // change as IR is in SSA form, but we should try not to change names of
+    // external inputs and outputs.
+
+    // Create a simple net:
+    //  net(ext_inputs = {a, b, c})
+    //    a = foo::bar(a, b)
+    //    u = foo::baz(b, c)
+    //    x = foo::qux(u, a)
+    //    x = foo::quux(a, x)
+    //    -> (ext_outputs = {x})
+    //
+    caffe2::NetDef net;
+
+    *net.add_op() = createOperator("foo::bar", {"a", "b"}, {"a"});
+    *net.add_op() = createOperator("foo::baz", {"b", "c"}, {"u"});
+    *net.add_op() = createOperator("foo::qux", {"u", "a"}, {"x"});
+    *net.add_op() = createOperator("foo::quux", {"a", "x", "u"}, {"x"});
+    net.add_external_input("a");
+    net.add_external_input("b");
+    net.add_external_input("c");
+    net.add_external_output("x");
+
+    // Expect the following graph to be generated:
+    //    graph(%a : Tensor,
+    //          %b : Tensor,
+    //          %c : Tensor) {
+    //      %a.1 : Tensor = foo::bar(%a, %b)
+    //      %u : Tensor = foo::baz(%b, %c)
+    //      %x.1 : Tensor = foo::qux(%u, %a.1)
+    //      %x : Tensor = foo::quux(%a.1, %x.1, u)
+    //      return (%x)
+    //    }
+    Graph graph;
+    std::unordered_map<std::string, Value*> vmap;
+    convertNetDefToIR(net, &graph, &vmap);
+    AT_ASSERT(graph.inputs().size() == 3);
+    AT_ASSERT(graph.inputs()[0]->uniqueName() == "a");
+    AT_ASSERT(graph.inputs()[1]->uniqueName() == "b");
+    AT_ASSERT(graph.inputs()[2]->uniqueName() == "c");
+
+    AT_ASSERT(graph.outputs().size() == 1);
+    AT_ASSERT(graph.outputs()[0]->uniqueName() == "x");
+
+    Node* quux = graph.outputs()[0]->node();
+    Value* a0 = quux->inputs()[0];
+    Value* x0 = quux->inputs()[1];
+    Value* u = quux->inputs()[2];
+    AT_ASSERT(a0->uniqueName() != "a" && a0->uniqueNameBase() == "a");
+    AT_ASSERT(x0->uniqueName() != "x" && x0->uniqueNameBase() == "x");
+    AT_ASSERT(u->uniqueName() == "u");
+
+    // Convert back to netdef and check if the names are preserved.
+    // We still expect them to be in SSA form, but we should preserve names for
+    // external inputs and outputs.
+    caffe2::NetDef net2;
+    convertIRToNetDef(&net2, graph);
+    AT_ASSERT(net2.external_input().Get(0) == "a");
+    AT_ASSERT(net2.external_input().Get(1) == "b");
+    AT_ASSERT(net2.external_input().Get(2) == "c");
+    AT_ASSERT(net2.external_output().Get(0) == "x");
+  }
 }
 } // namespace jit
 } // namespace torch
index 4c8ca5f..8f77165 100644 (file)
@@ -77,11 +77,13 @@ void convertNetDefToIR(
     std::unordered_map<std::string, Value*>* valueMapPtr,
     const std::string& prefix) {
   std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
+  std::unordered_map<Value*, std::string> namesMap;
   valueMap.clear();
 
   for (const auto& inputName : net.external_input()) {
     AT_ASSERT(!valueMap.count(inputName));
     valueMap[inputName] = g->addInput();
+    namesMap[valueMap.at(inputName)] = inputName;
   }
 
   for (const auto& op : net.op()) {
@@ -98,7 +100,9 @@ void convertNetDefToIR(
     for (const auto& output : op.output()) {
       // If output already exists in valueMap, overwrite it. This way we will
       // have the last definition of a value named 'output' in valueMap.
-      valueMap[output] = node->outputs()[idx++];
+      Value* v = node->outputs()[idx++];
+      valueMap[output] = v;
+      namesMap[v] = output;
     }
     for (const auto& arg : op.arg()) {
       convertArg(arg, node);
@@ -108,6 +112,35 @@ void convertNetDefToIR(
   for (const auto& outputName : net.external_output()) {
     AT_ASSERT(valueMap.count(outputName));
     g->registerOutput(valueMap.at(outputName));
+    namesMap[valueMap.at(outputName)] = outputName;
+  }
+
+  // Set proper unique names for all values.
+  // We will set the names for external inputs and outputs last, so that if the
+  // names are reused, then intermediate values will be renamed and the external
+  // values will keep the original names.
+  for (Node* n : g->nodes()) {
+    for (Value* v : n->outputs()) {
+      AT_ASSERT(namesMap.count(v));
+      const std::string& name = namesMap.at(v);
+      if (Value::isValidName(name)) {
+        v->setUniqueName(name);
+      }
+    }
+  }
+  for (Value* v : g->inputs()) {
+    AT_ASSERT(namesMap.count(v));
+    const std::string& name = namesMap.at(v);
+    if (Value::isValidName(name)) {
+      v->setUniqueName(name);
+    }
+  }
+  for (Value* v : g->outputs()) {
+    AT_ASSERT(namesMap.count(v));
+    const std::string& name = namesMap.at(v);
+    if (Value::isValidName(name)) {
+      v->setUniqueName(name);
+    }
   }
 }