JIT IR - Add option to remove prefix string when converting from JIT IR to NetDef...
authorDuc Ngo <duc@fb.com>
Tue, 12 Mar 2019 23:54:47 +0000 (16:54 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Mar 2019 00:02:26 +0000 (17:02 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17931

When converting from NetDef to IR and back, the prefix string should be removed so the operator types are preserved in caffe2.

Reviewed By: ZolotukhinM

Differential Revision: D14425954

fbshipit-source-id: 2807e7337b0f804f126970768b1250a4a8c5f35c

test/cpp/jit/test_netdef_converter.h
torch/csrc/jit/netdef_converter.cpp
torch/csrc/jit/netdef_converter.h

index 0ac7d92..9bec094 100644 (file)
@@ -219,6 +219,35 @@ void testNetDefConverter() {
     AT_ASSERT(net2.external_input().Get(2) == "c");
     AT_ASSERT(net2.external_output().Get(0) == "x");
   }
+
+  {
+    // Test that prefix is removed when converting from NetDef to IR and back.
+    caffe2::NetDef net;
+    *net.add_op() = createOperator("MatMul", {"a", "b"}, {"c"});
+    net.add_external_input("a");
+    net.add_external_input("b");
+    net.add_external_output("c");
+    Graph graph;
+    std::unordered_map<std::string, Value*> vmap;
+    convertNetDefToIR(net, &graph, &vmap, "caffe2::");
+
+    caffe2::NetDef net2;
+    convertIRToNetDef(&net2, graph, "caffe2::");
+    // The conversion should remove the prefix if it maches.
+    AT_ASSERT(net2.op(0).type() == "MatMul");
+
+    caffe2::NetDef net3;
+    convertIRToNetDef(&net3, graph, "foo::");
+    // The conversion should still work if the prefix does not match.
+    AT_ASSERT(net3.op(0).type() == "caffe2::MatMul");
+
+    // Prefix shouldn't affect blob names.
+    AT_ASSERT(net2.op(0).input(0) == "a");
+    AT_ASSERT(net2.external_input(0) == "a");
+    AT_ASSERT(net2.external_output(0) == "c");
+    AT_ASSERT(net3.external_input(0) == "a");
+  }
 }
+
 } // namespace jit
 } // namespace torch
index 8f77165..2e5dce3 100644 (file)
@@ -190,9 +190,19 @@ static void convertAttrToCaffe2Arg(
   }
 }
 
-static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
+static std::string removePrefixIfNeeded(const std::string& name,
+    const std::string& prefix) {
+  if (!name.compare(0, prefix.size(), prefix)) {
+    return name.substr(prefix.size());
+  } else {
+    return name;
+  }
+}
+
+static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net,
+  const std::string& prefix = "") {
   caffe2::OperatorDef op;
-  op.set_type(node->kind().toQualString());
+  op.set_type(removePrefixIfNeeded(node->kind().toQualString(), prefix));
   for (const Value* input : node->inputs()) {
     op.add_input(input->uniqueName());
   }
@@ -207,7 +217,8 @@ static void convertNodeToCaffe2Op(const Node* node, caffe2::NetDef* net) {
   *net->add_op() = op;
 }
 
-void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
+void convertIRToNetDef(caffe2::NetDef* net, const Graph& g,
+  const std::string& prefix) {
   net->mutable_op()->Clear();
 
   for (const Value* value : g.inputs()) {
@@ -215,7 +226,7 @@ void convertIRToNetDef(caffe2::NetDef* net, const Graph& g) {
   }
 
   for (const Node* node : g.nodes()) {
-    convertNodeToCaffe2Op(node, net);
+    convertNodeToCaffe2Op(node, net, prefix);
   }
 
   for (const Value* value : g.outputs()) {
index d2effb2..03fd014 100644 (file)
@@ -28,11 +28,16 @@ void convertNetDefToIR(
  * both formats will converge to PyTorch IR, so for now we try to keep as close
  * to it as possible. For short-term applications we might add a separate pass
  * that would fold such const-nodes into their users.
+ * \p If Prefix is specified, the prefix will be removed from operator name when
+ * converting from IR to NetDef.
  *
  * TODO: We might need to do a better job at preserving names of the variables,
  * especially external_inputs/external_outputs.
  */
-void convertIRToNetDef(caffe2::NetDef* net, const Graph& graph);
+void convertIRToNetDef(
+    caffe2::NetDef* net,
+    const Graph& graph,
+    const std::string& prefix = "");
 
 } // namespace jit
 } // namespace torch