Remove constant propagation expect files (#16348)
authorElias Ellison <eellison@fb.com>
Thu, 31 Jan 2019 23:37:52 +0000 (15:37 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 31 Jan 2019 23:41:22 +0000 (15:41 -0800)
Summary:
Remove constant prop expect files, and express graph conditions via python bindings.

First diff in larger effort to remove expect files
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16348

Differential Revision: D13906929

Pulled By: eellison

fbshipit-source-id: 7963caa3ccbc7bfc0006a160c952aa173d1ce633

test/expect/TestJit.test_constant_prop_if_constant.expect [deleted file]
test/expect/TestJit.test_constant_prop_loop_constant.expect [deleted file]
test/expect/TestJit.test_constant_prop_nested.expect [deleted file]
test/expect/TestJit.test_constant_prop_print.expect [deleted file]
test/expect/TestJit.test_constant_prop_rand.expect [deleted file]
test/expect/TestJit.test_constant_prop_simple.expect [deleted file]
test/test_jit.py
torch/csrc/jit/python_ir.cpp

diff --git a/test/expect/TestJit.test_constant_prop_if_constant.expect b/test/expect/TestJit.test_constant_prop_if_constant.expect
deleted file mode 100644 (file)
index d12558f..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-graph(%a : Tensor
-      %b : Tensor) {
-  %c0.1 : int = prim::Constant[value=1]()
-  %3 : bool = prim::Bool(%a)
-  %c0.5 : int, %c1 : int = prim::If(%3)
-    block0() {
-      %6 : bool = prim::Bool(%b)
-      %c0.4 : int = prim::If(%6)
-        block0() {
-          %8 : int = prim::Constant[value=2]()
-          -> (%8)
-        }
-        block1() {
-          -> (%c0.1)
-        }
-      -> (%c0.4, %c0.1)
-    }
-    block1() {
-      %9 : int = prim::Constant[value=2]()
-      -> (%c0.1, %9)
-    }
-  %c0.6 : int = aten::add(%c0.5, %c0.1)
-  %11 : int = prim::Constant[value=5]()
-  %12 : Tensor = aten::add(%a, %c0.6, %c0.1)
-  %13 : Tensor = aten::add(%12, %c1, %c0.1)
-  %14 : Tensor = aten::add(%13, %11, %c0.1)
-  return (%14);
-}
diff --git a/test/expect/TestJit.test_constant_prop_loop_constant.expect b/test/expect/TestJit.test_constant_prop_loop_constant.expect
deleted file mode 100644 (file)
index ff94569..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-graph() {
-  %0 : bool = prim::Constant[value=0]()
-  %1 : bool = prim::Constant[value=1]()
-  %b.1 : int = prim::Constant[value=0]()
-  %3 : int = prim::Constant[value=9223372036854775807]()
-  %4 : int = prim::Constant[value=1]()
-  %5 : int = prim::Constant[value=2]()
-  %b.2 : int = prim::Loop(%3, %1, %b.1)
-    block0(%7 : int, %8 : int) {
-      -> (%1, %4)
-    }
-  %b : int = prim::Loop(%3, %0, %b.2)
-    block0(%10 : int, %11 : int) {
-      -> (%0, %5)
-    }
-  return (%b);
-}
diff --git a/test/expect/TestJit.test_constant_prop_nested.expect b/test/expect/TestJit.test_constant_prop_nested.expect
deleted file mode 100644 (file)
index 950fbc1..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-graph(%a : Tensor) {
-  %1 : int = prim::Constant[value=2]()
-  %2 : Tensor = aten::lt(%a, %1)
-  %3 : bool = prim::Bool(%2)
-  %c : int = prim::If(%3)
-    block0() {
-      %5 : int = prim::Constant[value=5]()
-      -> (%5)
-    }
-    block1() {
-      %6 : int = prim::Constant[value=1]()
-      -> (%6)
-    }
-  return (%c);
-}
diff --git a/test/expect/TestJit.test_constant_prop_print.expect b/test/expect/TestJit.test_constant_prop_print.expect
deleted file mode 100644 (file)
index acea521..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%input_tensor : Tensor) {
-  %1 : int = prim::Constant[value=1]()
-  %2 : int = prim::Constant[value=6]()
-   = prim::Print(%2)
-  %3 : int = prim::Constant[value=8]()
-  %4 : Tensor = aten::add(%input_tensor, %3, %1)
-  return (%4);
-}
diff --git a/test/expect/TestJit.test_constant_prop_rand.expect b/test/expect/TestJit.test_constant_prop_rand.expect
deleted file mode 100644 (file)
index c2d8a27..0000000
+++ /dev/null
@@ -1,11 +0,0 @@
-graph() {
-  %0 : int = prim::Constant[value=1]()
-  %1 : Device = prim::Constant[value="cpu"]()
-  %2 : int = prim::Constant[value=0]()
-  %3 : int = prim::Constant[value=6]()
-  %4 : int = prim::Constant[value=2]()
-  %5 : int[] = prim::Constant[value=[3]]()
-  %a : Tensor = aten::randn(%5, %3, %2, %1)
-  %b : Tensor = aten::add(%a, %4, %0)
-  return (%b);
-}
diff --git a/test/expect/TestJit.test_constant_prop_simple.expect b/test/expect/TestJit.test_constant_prop_simple.expect
deleted file mode 100644 (file)
index bb4326e..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%input_tensor : Tensor) {
-  %1 : int = prim::Constant[value=1]()
-  %2 : int = prim::Constant[value=8]()
-  %3 : Tensor = aten::add(%input_tensor, %2, %1)
-  return (%3);
-}
index 95e09ab..33bc980 100644 (file)
@@ -1666,17 +1666,20 @@ class TestJit(JitTestCase):
 
     def test_constant_prop_simple(self):
         @torch.jit.script
-        def constant_prop(input_tensor):
+        def constant_prop(input_int):
+            # type: (int) -> int
             a = 2 * 3
             b = a + 2
-            return b + input_tensor
+            return b - input_int
 
-        x = torch.tensor(2)
-        out_ref = constant_prop(x)
+        out_ref = constant_prop(2)
         self.run_pass('constant_propagation', constant_prop.graph)
-        out_test = constant_prop(torch.tensor(2))
+        out_test = constant_prop(2)
         self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph_str = str(constant_prop.graph)
+        self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
+        const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
+        self.assertEqual(const, 8)
 
     def test_constant_prop_nested(self):
         @torch.jit.script
@@ -1691,7 +1694,10 @@ class TestJit(JitTestCase):
         self.run_pass('constant_propagation', constant_prop.graph)
         out_test = constant_prop(torch.tensor(2))
         self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(constant_prop.graph))
+        if_node = constant_prop.graph.findNode("prim::If")
+        for block in if_node.blocks():
+            for node in block.nodes():
+                self.assertTrue(node.kind() == "prim::Constant")
 
     def test_constant_prop_print(self):
         @torch.jit.script
@@ -1702,7 +1708,9 @@ class TestJit(JitTestCase):
             return b + input_tensor
 
         self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph = constant_prop.graph
+        print_node = graph.findNode("prim::Print")
+        self.assertTrue(print_node.input().toIValue() == 6)
 
     def test_constant_prop_rand(self):
         @torch.jit.script
@@ -1712,7 +1720,7 @@ class TestJit(JitTestCase):
             return b
 
         self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        self.assertTrue("aten::randn" in str(constant_prop.graph))
 
     def test_constant_prop_none(self):
         @torch.jit.script
@@ -1783,8 +1791,16 @@ class TestJit(JitTestCase):
                 c2 = c2 + 4  # set to 5
             return a + c0 + c1 + c2
 
-        self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph = constant_prop.graph
+        self.run_pass('constant_propagation', graph)
+        ifs = graph.findAllNodes("prim::If", recurse=False)
+        snd_if_inlined = len(ifs) == 1
+        self.assertTrue(snd_if_inlined)
+        first_if = ifs[0]
+        self.assertTrue(first_if.outputsSize() == 2)
+        second_if = first_if.findNode("prim::If", recurse=False)
+        self.assertTrue(second_if.outputsSize() == 1)
+        self.assertTrue(second_if.findNode("prim::If") is None)
 
     def test_constant_prop_loop_constant(self):
         @torch.jit.script
index 3e73bde..1c88709 100644 (file)
@@ -69,35 +69,55 @@ std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
   }
 }
 
-std::vector<Node*> findAllNodes(Block* block, Symbol kind) {
+std::vector<Node*> findAllNodes(
+    c10::ArrayRef<torch::jit::Block*> blocks,
+    Symbol kind,
+    bool recurse = true) {
   std::vector<Node*> ret;
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      auto nodes = findAllNodes(b, kind);
-      ret.insert(ret.end(), nodes.begin(), nodes.end());
-    }
-    if (n->kind() == kind) {
-      ret.push_back(n);
+  for (Block* block : blocks) {
+    for (Node* n : block->nodes()) {
+      if (n->kind() == kind) {
+        ret.push_back(n);
+      }
+      if (recurse) {
+        auto nodes = findAllNodes(n->blocks(), kind, recurse);
+        ret.insert(ret.end(), nodes.begin(), nodes.end());
+      }
     }
   }
   return ret;
 }
 
-Node* findNode(Block* block, Symbol kind) {
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      auto node = findNode(b, kind);
-      if (node != nullptr) {
-        return node;
+std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) {
+  std::vector<Block*> blocks = {block};
+  return findAllNodes(blocks, kind, recurse);
+}
+
+Node* findNode(
+    c10::ArrayRef<torch::jit::Block*> blocks,
+    Symbol kind,
+    bool recurse = true) {
+  for (Block* block : blocks) {
+    for (Node* n : block->nodes()) {
+      if (n->kind() == kind) {
+        return n;
+      }
+      if (recurse) {
+        auto node = findNode(n->blocks(), kind, recurse);
+        if (node != nullptr) {
+          return node;
+        }
       }
-    }
-    if (n->kind() == kind) {
-      return n;
     }
   }
   return nullptr;
 }
 
+Node* findNode(Block* block, Symbol kind, bool recurse = true) {
+  std::vector<Block*> blocks = {block};
+  return findNode(blocks, kind, recurse);
+}
+
 // execute a Python function, used for Ops we can't optimize but that we want to
 // optimize around
 struct ConcretePythonOp : public PythonOp {
@@ -269,14 +289,15 @@ void initPythonIRBindings(PyObject* module_) {
           })
       .def(
           "findNode",
-          [](Graph& g, const std::string& kind) {
-            return findNode(g.block(), Symbol::fromQualString(kind));
-          })
+          [](Graph& g, const std::string& kind, bool recurse) {
+            return findNode(g.block(), Symbol::fromQualString(kind), recurse);
+          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
       .def(
           "findAllNodes",
-          [](Graph& g, const std::string& kind) {
-            return findAllNodes(g.block(), Symbol::fromQualString(kind));
-          })
+          [](Graph& g, const std::string& kind, bool recurse) {
+            return findAllNodes(
+                g.block(), Symbol::fromQualString(kind), recurse);
+          }, "Find all nodes",  py::arg("kind"), py::arg("recurse") = true)
       .def("addInput", [](Graph& g) { return g.addInput(); })
       .def("copy", [](Graph& g) { return g.copy(); })
       .GS(eraseInput)
@@ -361,7 +382,18 @@ void initPythonIRBindings(PyObject* module_) {
   py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
       .def("nodes", [](Block& b) {
         return py::make_iterator(b.nodes().begin(), b.nodes().end());
-      });
+      })
+      .def(
+          "findNode",
+          [](Block& b, const std::string& kind, bool recurse) {
+            return findNode(&b, Symbol::fromQualString(kind), recurse);
+          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+      .def(
+          "findAllNodes",
+          [](Block& b, const std::string& kind, bool recurse) {
+            return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
+          }, "Find all nodes",  py::arg("kind"), py::arg("recurse") = true);
+
 
 #define NS(name) def(#name, &Node ::name)
   py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
@@ -400,26 +432,16 @@ void initPythonIRBindings(PyObject* module_) {
       .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
       .def(
           "findNode",
-          [](Node& n, const std::string& kind) {
-            Node* node;
-            for (Block* b : n.blocks()) {
-              node = findNode(b, Symbol::fromQualString(kind));
-              if (node != nullptr) {
-                return node;
-              }
-            }
-            return node;
-          })
+          [](Node& n, const std::string& kind, bool recurse) {
+            return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
+          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
       .def(
           "findAllNodes",
-          [](Node& n, const std::string& kind) {
-            std::vector<Node*> ret;
-            for (Block* b : n.blocks()) {
-              auto nodes = findAllNodes(b, Symbol::fromQualString(kind));
-              ret.insert(ret.end(), nodes.begin(), nodes.end());
-            }
-            return ret;
-          })
+          [](Node& n, const std::string& kind, bool recurse) {
+            return findAllNodes(
+                n.blocks(), Symbol::fromQualString(kind), recurse);
+          }, "Find all nodes",  py::arg("kind"), py::arg("recurse") = true)
+      .def("input", [](Node& n) { return n.input(); })
       .def("output", [](Node& n) { return n.output(); })
       .NS(addInput)
       .NS(replaceInput)