Implement NetDef <--> JIT IR converters. (#16967)
authorMikhail Zolotukhin <mvz@fb.com>
Thu, 14 Feb 2019 02:15:57 +0000 (18:15 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Feb 2019 02:39:39 +0000 (18:39 -0800)
Summary:
Currently the converters are very straightforward, i.e. there is no code for trying to
preserve semantics, we're purely perform conversion from one format to another.

Two things that we might want to add/change:
1. Add semantic conversion as well (but probably it would be a good idea to keep
it separate as a temporary thing).
2. Make sure we don't mess with value names, as they are crucial for current
uses of NetDefs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16967

Differential Revision: D14062537

Pulled By: ZolotukhinM

fbshipit-source-id: 88b184ee7276779e5e9152b149d69857515ad98a

test/cpp/jit/gtest.cpp
test/cpp/jit/no-gtest.cpp
test/cpp/jit/test_netdef_converter.h [new file with mode: 0644]
torch/CMakeLists.txt
torch/csrc/jit/netdef_converter.cpp [new file with mode: 0644]
torch/csrc/jit/netdef_converter.h [new file with mode: 0644]

index 0434418..8e9c094 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <test/cpp/jit/test_alias_analysis.h>
 #include <test/cpp/jit/test_misc.h>
+#include <test/cpp/jit/test_netdef_converter.h>
 
 using namespace torch;
 using namespace torch::jit;
@@ -34,6 +35,8 @@ JIT_TEST(SubgraphUtils)
 JIT_TEST(AliasAnalysis)
 JIT_TEST(AliasTracker)
 
+JIT_TEST(NetDefConverter)
+
 JIT_TEST(THNNConv)
 JIT_TEST(ATenNativeBatchNorm)
 
index 9f4d4ce..72307a6 100644 (file)
@@ -1,5 +1,6 @@
 #include <test/cpp/jit/test_alias_analysis.h>
 #include <test/cpp/jit/test_misc.h>
+#include <test/cpp/jit/test_netdef_converter.h>
 
 #include <sstream>
 #include <string>
@@ -37,6 +38,7 @@ std::string runJITCPPTests() {
   testRegisterFusionCachesKernel();
   testAliasAnalysis();
   testAliasTracker();
+  testNetDefConverter(out);
   return out.str();
 }
 
diff --git a/test/cpp/jit/test_netdef_converter.h b/test/cpp/jit/test_netdef_converter.h
new file mode 100644 (file)
index 0000000..7113ebf
--- /dev/null
@@ -0,0 +1,146 @@
+#pragma once
+
+#include <torch/csrc/jit/netdef_converter.h>
+#include "test/cpp/jit/test_base.h"
+
+#include <sstream>
+#include <string>
+
+namespace torch {
+namespace jit {
+
+void testNetDefConverter(std::ostream& out = std::cout) {
+  {
+    // Check a simple net conversion back and forth.
+
+    // Create a simple graph:
+    //    graph(%0 : Tensor
+    //          %1 : Tensor) {
+    //      %2 : Tensor = aten::mul(%0, %1)
+    //      %3 : int = prim::Constant[value=1]()
+    //      %4 : Tensor = aten::add(%0, %2, %3)
+    //      return (%2, %4);
+    //    }
+    auto graph = std::make_shared<Graph>();
+    auto a = graph->addInput();
+    auto b = graph->addInput();
+    auto c = graph->insert(aten::mul, {a, b});
+    auto d = graph->insert(aten::add, {a, c});
+    graph->registerOutput(c);
+    graph->registerOutput(d);
+
+    // Convert it to netdef and check the result
+    caffe2::NetDef net;
+    convertIRToNetDef(&net, *graph);
+    AT_ASSERT(net.op().size() == 3);
+    AT_ASSERT(net.external_input().size() == 2);
+    AT_ASSERT(net.external_output().size() == 2);
+
+    const caffe2::OperatorDef& MulOp = net.op().Get(0);
+    AT_ASSERT(MulOp.input().size() == 2);
+    AT_ASSERT(MulOp.input().Get(0) == net.external_input().Get(0));
+    AT_ASSERT(MulOp.input().Get(1) == net.external_input().Get(1));
+    AT_ASSERT(MulOp.output().size() == 1);
+
+    const caffe2::OperatorDef& ConstNode = net.op().Get(1);
+    AT_ASSERT(ConstNode.input().size() == 0);
+    AT_ASSERT(ConstNode.output().size() == 1);
+    AT_ASSERT(ConstNode.arg().size() == 1);
+    AT_ASSERT(ConstNode.arg().Get(0).name() == "value");
+    AT_ASSERT(ConstNode.arg().Get(0).i() == 1);
+
+    const caffe2::OperatorDef& AddOp = net.op().Get(2);
+    AT_ASSERT(AddOp.input().size() == 3);
+    AT_ASSERT(AddOp.input().Get(0) == net.external_input().Get(0));
+    AT_ASSERT(AddOp.input().Get(1) == MulOp.output().Get(0));
+    AT_ASSERT(AddOp.input().Get(2) == ConstNode.output().Get(0));
+
+    AT_ASSERT(net.external_output().Get(0) == MulOp.output().Get(0));
+    AT_ASSERT(net.external_output().Get(1) == AddOp.output().Get(0));
+
+    // Convert NetDef back to IR and check if we get the original.
+    Graph graph2;
+    std::unordered_map<std::string, Value*> vmap;
+    convertNetDefToIR(net, &graph2, &vmap);
+
+    Node* mul = graph2.outputs()[0]->node();
+    Node* add = graph2.outputs()[1]->node();
+    AT_ASSERT(mul->kind() == c->node()->kind());
+    AT_ASSERT(add->kind() == d->node()->kind());
+    AT_ASSERT(mul->inputs()[0] == graph2.inputs()[0]);
+    AT_ASSERT(mul->inputs()[1] == graph2.inputs()[1]);
+    AT_ASSERT(add->inputs()[0] == graph2.inputs()[0]);
+    AT_ASSERT(add->inputs()[1] == graph2.outputs()[0]);
+  }
+  {
+    // Check attributes conversion
+    auto graph = std::make_shared<Graph>();
+    auto a = graph->addInput();
+    auto b = graph->addInput();
+    Node* node =
+        graph->create(Symbol::fromQualString("test::some_op"), {a, b}, 2);
+    graph->insertNode(node);
+
+    node->i_(Symbol::fromQualString("attr::i_attr"), 42);
+    node->f_(Symbol::fromQualString("attr::f_attr"), 3.0);
+    node->s_(Symbol::fromQualString("attr::s_attr"), "Hello!");
+
+    node->is_(Symbol::fromQualString("attr::is_attr"), {14, 18, 7});
+    node->fs_(Symbol::fromQualString("attr::fs_attr"), {2.72, 3.14});
+    node->ss_(Symbol::fromQualString("attr::ss_attr"), {"Winter", "Summer"});
+
+    graph->registerOutput(node->outputs()[0]);
+    graph->registerOutput(node->outputs()[1]);
+
+    // Convert it to netdef and check the result
+    caffe2::NetDef net;
+    convertIRToNetDef(&net, *graph);
+    const caffe2::OperatorDef& Op = net.op().Get(0);
+    AT_ASSERT(Op.arg().Get(0).name() == "i_attr");
+    AT_ASSERT(Op.arg().Get(0).i() == 42);
+    AT_ASSERT(Op.arg().Get(1).name() == "f_attr");
+    AT_ASSERT(Op.arg().Get(1).f() == 3.0);
+    AT_ASSERT(Op.arg().Get(2).name() == "s_attr");
+    AT_ASSERT(Op.arg().Get(2).s() == "Hello!");
+
+    AT_ASSERT(Op.arg().Get(3).name() == "is_attr");
+    AT_ASSERT(Op.arg().Get(3).ints().size() == 3);
+    AT_ASSERT(Op.arg().Get(3).ints().Get(0) == 14);
+    AT_ASSERT(Op.arg().Get(3).ints().Get(1) == 18);
+    AT_ASSERT(Op.arg().Get(3).ints().Get(2) == 7);
+
+    AT_ASSERT(Op.arg().Get(4).name() == "fs_attr");
+    AT_ASSERT(Op.arg().Get(4).floats().size() == 2);
+    AT_ASSERT(fabs(Op.arg().Get(4).floats().Get(0) - 2.72) < 0.001);
+
+    AT_ASSERT(Op.arg().Get(5).name() == "ss_attr");
+    AT_ASSERT(Op.arg().Get(5).strings().size() == 2);
+    AT_ASSERT(Op.arg().Get(5).strings().Get(1) == "Summer");
+
+    AT_ASSERT(net.external_output().Get(0) == Op.output().Get(0));
+    AT_ASSERT(net.external_output().Get(1) == Op.output().Get(1));
+
+    // Convert NetDef back to IR and check if we get the original.
+    Graph graph2;
+    std::unordered_map<std::string, Value*> vmap;
+    convertNetDefToIR(net, &graph2, &vmap);
+
+    AT_ASSERT(graph2.outputs()[0]->node() == graph2.outputs()[0]->node());
+    Node* n = graph2.outputs()[0]->node();
+    AT_ASSERT(n->i(Symbol::fromQualString("attr::i_attr")) == 42);
+    AT_ASSERT(n->f(Symbol::fromQualString("attr::f_attr")) == 3.0);
+    AT_ASSERT(n->s(Symbol::fromQualString("attr::s_attr")) == "Hello!");
+    AT_ASSERT(
+        n->is(Symbol::fromQualString("attr::is_attr")) ==
+        std::vector<long>({14, 18, 7}));
+    AT_ASSERT(
+        fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[0] - 2.72) < 0.001);
+    AT_ASSERT(
+        fabs(n->fs(Symbol::fromQualString("attr::fs_attr"))[1] - 3.14) < 0.001);
+    AT_ASSERT(
+        n->ss(Symbol::fromQualString("attr::ss_attr")) ==
+        std::vector<std::string>({"Winter", "Summer"}));
+  }
+}
+} // namespace jit
+} // namespace torch
index 7be2a60..2d2bb4f 100644 (file)
@@ -133,6 +133,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/constants.cpp
   ${TORCH_SRC_DIR}/csrc/jit/node_hashing.cpp
   ${TORCH_SRC_DIR}/csrc/jit/ir.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp
   ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
   ${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp
diff --git a/torch/csrc/jit/netdef_converter.cpp b/torch/csrc/jit/netdef_converter.cpp
new file mode 100644 (file)
index 0000000..64e4552
--- /dev/null
@@ -0,0 +1,194 @@
+#include <torch/csrc/jit/netdef_converter.h>
+
+namespace torch {
+namespace jit {
+
+static AttributeKind getArgKind(const caffe2::Argument& arg) {
+  if (arg.has_i()) {
+    return AttributeKind::i;
+  } else if (arg.has_f()) {
+    return AttributeKind::f;
+  } else if (arg.has_s()) {
+    return AttributeKind::s;
+  } else if (arg.has_t()) {
+    return AttributeKind::t;
+  } else if (arg.has_n()) {
+    return AttributeKind::g;
+  } else if (arg.ints().size()) {
+    return AttributeKind::is;
+  } else if (arg.floats().size()) {
+    return AttributeKind::fs;
+  } else if (arg.strings().size()) {
+    return AttributeKind::ss;
+  } else if (arg.tensors().size()) {
+    return AttributeKind::ts;
+  } else if (arg.nets().size()) {
+    return AttributeKind::gs;
+  }
+  // Unknown type.
+  abort();
+}
+
+static void convertArg(const caffe2::Argument& arg, Node* node) {
+  std::string attrName = "attr::" + arg.name();
+  auto attrSymbol = Symbol::fromQualString(attrName);
+  AttributeKind kind = getArgKind(arg);
+  switch (kind) {
+    case AttributeKind::i: {
+      node->i_(attrSymbol, (long)arg.i());
+      break;
+    }
+    case AttributeKind::f: {
+      node->f_(attrSymbol, arg.f());
+      break;
+    }
+    case AttributeKind::s: {
+      node->s_(attrSymbol, arg.s());
+      break;
+    }
+    case AttributeKind::is: {
+      std::vector<long> is(arg.ints().begin(), arg.ints().end());
+      node->is_(attrSymbol, is);
+      break;
+    }
+    case AttributeKind::fs: {
+      std::vector<double> fs(arg.floats().begin(), arg.floats().end());
+      node->fs_(attrSymbol, fs);
+      break;
+    }
+    case AttributeKind::ss: {
+      std::vector<std::string> ss(arg.strings().begin(), arg.strings().end());
+      node->ss_(attrSymbol, ss);
+      break;
+    }
+    default: {
+      std::cout << "Unsupported type '" << toString(kind) << "' of attribute '"
+                << attrName << "'"
+                << " in node:" << std::endl;
+      node->dump();
+      abort();
+    }
+  }
+}
+
+void convertNetDefToIR(
+    const caffe2::NetDef& net,
+    Graph* g,
+    std::unordered_map<std::string, Value*>* valueMapPtr,
+    const std::string& prefix) {
+  std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
+  valueMap.clear();
+
+  for (const auto& inputName : net.external_input()) {
+    AT_ASSERT(!valueMap.count(inputName));
+    valueMap[inputName] = g->addInput();
+  }
+
+  for (const auto& op : net.op()) {
+    std::string name = prefix + op.type();
+    Node* node =
+        g->create(Symbol::fromQualString(name), {}, op.output().size());
+    g->insertNode(node);
+
+    for (const auto& input : op.input()) {
+      AT_ASSERT(valueMap.count(input));
+      node->addInput(valueMap[input]);
+    }
+    int idx = 0;
+    for (const auto& output : op.output()) {
+      // If output already exists in valueMap, overwrite it. This way we will
+      // have the last definition of a value named 'output' in valueMap.
+      valueMap[output] = node->outputs()[idx++];
+    }
+    for (const auto& arg : op.arg()) {
+      convertArg(arg, node);
+    }
+  }
+
+  for (const auto& outputName : net.external_output()) {
+    AT_ASSERT(valueMap.count(outputName));
+    g->registerOutput(valueMap.at(outputName));
+  }
+}
+
+static void convertAttrToCaffe2Arg(
+    const Node* node,
+    const Symbol& name,
+    caffe2::Argument* arg) {
+  arg->set_name(name.toUnqualString());
+  switch (node->kindOf(name)) {
+    case AttributeKind::i: {
+      arg->set_i(node->i(name));
+      break;
+    }
+    case AttributeKind::f: {
+      arg->set_f(node->f(name));
+      break;
+    }
+    case AttributeKind::s: {
+      arg->set_s(node->s(name));
+      break;
+    }
+    case AttributeKind::is: {
+      for (long i : node->is(name)) {
+        arg->add_ints(i);
+      }
+      break;
+    }
+    case AttributeKind::fs: {
+      for (double f : node->fs(name)) {
+        arg->add_floats(f);
+      }
+      break;
+    }
+    case AttributeKind::ss: {
+      for (const std::string& s : node->ss(name)) {
+        arg->add_strings(s);
+      }
+      break;
+    }
+    default: {
+      std::cout << "Unsupported type '" << toString(node->kindOf(name))
+                << "' of attribute '" << name.toUnqualString() << "'"
+                << " in node:" << std::endl;
+      node->dump();
+      abort();
+    }
+  }
+}
+
+static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
+  caffe2::OperatorDef op;
+  op.set_type(node->kind().toQualString());
+  for (const Value* input : node->inputs()) {
+    op.add_input(input->uniqueName());
+  }
+  for (const Value* output : node->outputs()) {
+    op.add_output(output->uniqueName());
+  }
+  std::vector<Symbol> names = node->attributeNames();
+  for (const Symbol& name : names) {
+    caffe2::Argument* arg = op.add_arg();
+    convertAttrToCaffe2Arg(node, name, arg);
+  }
+  *net->add_op() = op;
+}
+
+void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
+  net->mutable_op()->Clear();
+
+  for (const Value* value : g.inputs()) {
+    net->add_external_input(value->uniqueName());
+  }
+
+  for (const Node* node : g.nodes()) {
+    convertNodeToCaffe2Op(node, net);
+  }
+
+  for (const Value* value : g.outputs()) {
+    net->add_external_output(value->uniqueName());
+  }
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/netdef_converter.h b/torch/csrc/jit/netdef_converter.h
new file mode 100644 (file)
index 0000000..d2effb2
--- /dev/null
@@ -0,0 +1,38 @@
+#pragma once
+#include <caffe2/proto/caffe2_pb.h>
+#include <torch/csrc/jit/ir.h>
+#include <unordered_map>
+
+namespace torch {
+namespace jit {
+
+/** \brief Convert a caffe2 NetDef to PyTorch IR.
+ *
+ * The NetDef \p net is converted and the result is stored in the
+ * torch::jit::Graph \p graph. The function also records name->value map in \p
+ * valueMapPtr. If the original net had several values with the same name, the
+ * map will contain the value for the last definition.
+ * \p Prefix can be used for appending some string to every operator name (e.g.
+ * we can add "caffe2::").
+ */
+void convertNetDefToIR(
+    const caffe2::NetDef& net,
+    Graph* graph,
+    std::unordered_map<std::string, Value*>* valueMapPtr,
+    const std::string& prefix = "");
+
+/** \brief Convert PyTorch IR \p graph to Caffe2 NetDef \p net.
+ *
+ * Note: for constant nodes (prim::Const) we generate a separate op in the net,
+ * which might or might not be what we want. The idea here is that eventually
+ * both formats will converge to PyTorch IR, so for now we try to keep as close
+ * to it as possible. For short-term applications we might add a separate pass
+ * that would fold such const-nodes into their users.
+ *
+ * TODO: We might need to do a better job at preserving names of the variables,
+ * especially external_inputs/external_outputs.
+ */
+void convertIRToNetDef(caffe2::NetDef* net, const Graph& graph);
+
+} // namespace jit
+} // namespace torch