JIT IR - Make valueMapPtr optional in convertNetDefToIR (#17942)
authorDuc Ngo <duc@fb.com>
Thu, 14 Mar 2019 19:16:12 +0000 (12:16 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 14 Mar 2019 19:22:49 +0000 (12:22 -0700)
Summary:
Make valueMapPtr optional in convertNetDefToIR, and add tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17942

Differential Revision: D14429687

Pulled By: duc0

fbshipit-source-id: 3a5a72bbb5acc1bfd7144a987688c599016fbf7a

test/cpp/jit/test_netdef_converter.h
torch/csrc/jit/netdef_converter.cpp
torch/csrc/jit/netdef_converter.h

index 9bec094..fdf31b8 100644 (file)
@@ -230,6 +230,8 @@ void testNetDefConverter() {
     Graph graph;
     std::unordered_map<std::string, Value*> vmap;
     convertNetDefToIR(net, &graph, &vmap, "caffe2::");
+    // Sanity check that value map is returned and it works.
+    AT_ASSERT(vmap["a"]->uniqueName() == "a");
 
     caffe2::NetDef net2;
     convertIRToNetDef(&net2, graph, "caffe2::");
@@ -246,6 +248,10 @@ void testNetDefConverter() {
     AT_ASSERT(net2.external_input(0) == "a");
     AT_ASSERT(net2.external_output(0) == "c");
     AT_ASSERT(net3.external_input(0) == "a");
+
+    Graph graph2;
+    // Test that conversion works without passing in a valueMap.
+    convertNetDefToIR(net, &graph2, nullptr, "caffe2::");
   }
 }
 
index 2e5dce3..78aff33 100644 (file)
@@ -76,6 +76,12 @@ void convertNetDefToIR(
     Graph* g,
     std::unordered_map<std::string, Value*>* valueMapPtr,
     const std::string& prefix) {
+  if (!valueMapPtr) {
+    std::unordered_map<std::string, Value*> localValueMap;
+    // If valueMapPtr is null, we just use a local map since we don't need
+    // to return the valueMap to the caller.
+    return convertNetDefToIR(net, g, &localValueMap, prefix);
+  }
   std::unordered_map<std::string, Value*>& valueMap = *valueMapPtr;
   std::unordered_map<Value*, std::string> namesMap;
   valueMap.clear();
index 03fd014..7d82433 100644 (file)
@@ -11,14 +11,14 @@ namespace jit {
  * The NetDef \p net is converted and the result is stored in the
  * torch::jit::Graph \p graph. The function also records name->value map in \p
  * valueMapPtr. If the original net had several values with the same name, the
- * map will contain the value for the last definition.
+ * map will contain the value for the last definition. valueMapPtr is optional.
  * \p Prefix can be used for appending some string to every operator name (e.g.
  * we can add "caffe2::").
  */
 void convertNetDefToIR(
     const caffe2::NetDef& net,
     Graph* graph,
-    std::unordered_map<std::string, Value*>* valueMapPtr,
+    std::unordered_map<std::string, Value*>* valueMapPtr = nullptr,
     const std::string& prefix = "");
 
 /** \brief Convert PyTorch IR \p graph to Caffe2 NetDef \p net.