IRParser: optionally create name->value map of the parsed IR. (#19551)
authorMikhail Zolotukhin <mvz@fb.com>
Mon, 22 Apr 2019 23:02:40 +0000 (16:02 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 22 Apr 2019 23:09:05 +0000 (16:09 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19551
ghimport-source-id: e666e3c00786a3b1c747f2dd6e85a48a63bdd69d

Differential Revision: D15028056

Pulled By: ZolotukhinM

fbshipit-source-id: 37e08d6df1d43513748ecfdd8549738eac7ec24e

test/cpp/jit/test_irparser.h
torch/csrc/jit/irparser.cpp
torch/csrc/jit/irparser.h

index af8d28a..5e48164 100644 (file)
@@ -43,6 +43,7 @@ static void checkRoundtrip(const std::string& s) {
 void testIRParser() {
   {
     auto graph = std::make_shared<Graph>();
+    std::unordered_map<std::string, Value*> vmap;
     script::parseIR(
         R"IR(
 graph(%0 : Tensor, %1 : Tensor):
@@ -50,7 +51,8 @@ graph(%0 : Tensor, %1 : Tensor):
   %res, %3 = foo::mul(%0, %2)
   %x, %y = foo::combine(%res, %2, %3)
   return (%x, %y, %res))IR",
-        &*graph);
+        &*graph,
+        vmap);
 
     AT_ASSERT(graph->inputs().size() == 2);
     AT_ASSERT(graph->outputs().size() == 3);
@@ -59,10 +61,17 @@ graph(%0 : Tensor, %1 : Tensor):
     Value* res = graph->outputs()[2];
     Value* t0 = graph->inputs()[0];
     Value* t1 = graph->inputs()[1];
+    AT_ASSERT(vmap["x"] == x);
+    AT_ASSERT(vmap["y"] == y);
+    AT_ASSERT(vmap["res"] == res);
+    AT_ASSERT(vmap["0"] == t0);
+    AT_ASSERT(vmap["1"] == t1);
     AT_ASSERT(x->node() == y->node());
     Node* comb = x->node();
     Value* t2 = comb->inputs()[1];
     Value* t3 = comb->inputs()[2];
+    AT_ASSERT(vmap["2"] == t2);
+    AT_ASSERT(vmap["3"] == t3);
     AT_ASSERT(comb->kind().toQualString() == std::string("foo::combine"));
     AT_ASSERT(comb->outputs() == std::vector<Value*>({x, y}));
     AT_ASSERT(comb->inputs() == std::vector<Value*>({res, t2, t3}));
index 4af5c80..fbf7fe5 100644 (file)
@@ -15,10 +15,17 @@ struct VarWithType;
 struct ParsedLiteral;
 
 class IRParser {
-  friend void parseIR(const std::string& str, torch::jit::Graph* graph);
-  IRParser(const std::string& str, torch::jit::Graph* graph)
+  friend void parseIR(
+      const std::string& str,
+      torch::jit::Graph* graph,
+      std::unordered_map<std::string, Value*>& vmap);
+  IRParser(
+      const std::string& str,
+      torch::jit::Graph* graph,
+      std::unordered_map<std::string, Value*>& vmap)
       : L(str),
         g(graph),
+        vmap(vmap),
         type_parser(L, /*parse_complete_tensor_types*/ true) {}
 
   std::string parseVar();
@@ -50,7 +57,7 @@ class IRParser {
 
   torch::jit::script::Lexer L;
   torch::jit::Graph* g = nullptr;
-  std::unordered_map<std::string, Value*> vmap;
+  std::unordered_map<std::string, Value*>& vmap;
   SchemaTypeParser type_parser;
 };
 
@@ -73,11 +80,19 @@ struct VarWithType {
   TypePtr type;
 };
 
-void parseIR(const std::string& str, torch::jit::Graph* graph) {
-  torch::jit::script::IRParser p(str, graph);
+void parseIR(
+    const std::string& str,
+    torch::jit::Graph* graph,
+    std::unordered_map<std::string, Value*>& vmap) {
+  torch::jit::script::IRParser p(str, graph, vmap);
   p.parse();
 }
 
+void parseIR(const std::string& str, torch::jit::Graph* graph) {
+  std::unordered_map<std::string, Value*> vmap;
+  parseIR(str, graph, vmap);
+}
+
 VarWithType IRParser::parseVarWithType() {
   VarWithType r;
   r.name = parseVar();
index 2fc835c..d9abc13 100644 (file)
@@ -1,16 +1,28 @@
-#include <string>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <string>
+#include <unordered_map>
 
 namespace torch {
 namespace jit {
 
 struct Graph;
+struct Value;
 
 namespace script {
 
 // \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH.
 TORCH_API void parseIR(const std::string& str, torch::jit::Graph* graph);
 
+/** \brief Parse IR from \p STR constructing the corresponding IR in\ GRAPH.
+ *
+ * \p VMAP is filled with String to Value pairs allowing to index Values in the
+ * newly created graph by their name in the original IR string.
+ */
+TORCH_API void parseIR(
+    const std::string& str,
+    torch::jit::Graph* graph,
+    std::unordered_map<std::string, Value*>& vmap);
+
 } // namespace script
 } // namespace jit
 } // namespace torch