From a386c28fcd7232ddf45d376812d2d0a5729b292c Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Thu, 31 Jan 2019 15:37:52 -0800
Subject: [PATCH] Remove constant propagation expect files (#16348)

Summary:
Remove constant prop expect files, and express graph conditions via python bindings.

First diff in a larger effort to remove expect files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16348

Differential Revision: D13906929

Pulled By: eellison

fbshipit-source-id: 7963caa3ccbc7bfc0006a160c952aa173d1ce633
---
 .../TestJit.test_constant_prop_if_constant.expect  |  28 ------
 ...TestJit.test_constant_prop_loop_constant.expect |  17 ----
 .../TestJit.test_constant_prop_nested.expect       |  15 ---
 .../expect/TestJit.test_constant_prop_print.expect |   8 --
 test/expect/TestJit.test_constant_prop_rand.expect |  11 ---
 .../TestJit.test_constant_prop_simple.expect       |   6 --
 test/test_jit.py                                   |  38 +++++---
 torch/csrc/jit/python_ir.cpp                       | 106 +++++++++++++--------
 8 files changed, 91 insertions(+), 138 deletions(-)
 delete mode 100644 test/expect/TestJit.test_constant_prop_if_constant.expect
 delete mode 100644 test/expect/TestJit.test_constant_prop_loop_constant.expect
 delete mode 100644 test/expect/TestJit.test_constant_prop_nested.expect
 delete mode 100644 test/expect/TestJit.test_constant_prop_print.expect
 delete mode 100644 test/expect/TestJit.test_constant_prop_rand.expect
 delete mode 100644 test/expect/TestJit.test_constant_prop_simple.expect

diff --git a/test/expect/TestJit.test_constant_prop_if_constant.expect b/test/expect/TestJit.test_constant_prop_if_constant.expect
deleted file mode 100644
index d12558f..0000000
--- a/test/expect/TestJit.test_constant_prop_if_constant.expect
+++ /dev/null
@@ -1,28 +0,0 @@
-graph(%a : Tensor
-      %b : Tensor) {
-  %c0.1 : int = prim::Constant[value=1]()
-  %3 : bool = prim::Bool(%a)
-  %c0.5 : int, %c1 : int = prim::If(%3)
-    block0() {
-      %6 : bool = prim::Bool(%b)
-      %c0.4 : int = prim::If(%6)
-        block0() {
-          %8 : int = prim::Constant[value=2]()
-          -> (%8)
-        }
-        block1() {
-          -> (%c0.1)
-        }
-      -> (%c0.4, %c0.1)
-    }
-    block1() {
-      %9 : int = prim::Constant[value=2]()
-      -> (%c0.1, %9)
-    }
-  %c0.6 : int = aten::add(%c0.5, %c0.1)
-  %11 : int = prim::Constant[value=5]()
-  %12 : Tensor = aten::add(%a, %c0.6, %c0.1)
-  %13 : Tensor = aten::add(%12, %c1, %c0.1)
-  %14 : Tensor = aten::add(%13, %11, %c0.1)
-  return (%14);
-}
diff --git a/test/expect/TestJit.test_constant_prop_loop_constant.expect b/test/expect/TestJit.test_constant_prop_loop_constant.expect
deleted file mode 100644
index ff94569..0000000
--- a/test/expect/TestJit.test_constant_prop_loop_constant.expect
+++ /dev/null
@@ -1,17 +0,0 @@
-graph() {
-  %0 : bool = prim::Constant[value=0]()
-  %1 : bool = prim::Constant[value=1]()
-  %b.1 : int = prim::Constant[value=0]()
-  %3 : int = prim::Constant[value=9223372036854775807]()
-  %4 : int = prim::Constant[value=1]()
-  %5 : int = prim::Constant[value=2]()
-  %b.2 : int = prim::Loop(%3, %1, %b.1)
-    block0(%7 : int, %8 : int) {
-      -> (%1, %4)
-    }
-  %b : int = prim::Loop(%3, %0, %b.2)
-    block0(%10 : int, %11 : int) {
-      -> (%0, %5)
-    }
-  return (%b);
-}
diff --git a/test/expect/TestJit.test_constant_prop_nested.expect b/test/expect/TestJit.test_constant_prop_nested.expect
deleted file mode 100644
index 950fbc1..0000000
--- a/test/expect/TestJit.test_constant_prop_nested.expect
+++ /dev/null
@@ -1,15 +0,0 @@
-graph(%a : Tensor) {
-  %1 : int = prim::Constant[value=2]()
-  %2 : Tensor = aten::lt(%a, %1)
-  %3 : bool = prim::Bool(%2)
-  %c : int = prim::If(%3)
-    block0() {
-      %5 : int = prim::Constant[value=5]()
-      -> (%5)
-    }
-    block1() {
-      %6 : int = prim::Constant[value=1]()
-      -> (%6)
-    }
-  return (%c);
-}
diff --git a/test/expect/TestJit.test_constant_prop_print.expect b/test/expect/TestJit.test_constant_prop_print.expect
deleted file mode 100644
index acea521..0000000
--- a/test/expect/TestJit.test_constant_prop_print.expect
+++ /dev/null
@@ -1,8 +0,0 @@
-graph(%input_tensor : Tensor) {
-  %1 : int = prim::Constant[value=1]()
-  %2 : int = prim::Constant[value=6]()
-   = prim::Print(%2)
-  %3 : int = prim::Constant[value=8]()
-  %4 : Tensor = aten::add(%input_tensor, %3, %1)
-  return (%4);
-}
diff --git a/test/expect/TestJit.test_constant_prop_rand.expect b/test/expect/TestJit.test_constant_prop_rand.expect
deleted file mode 100644
index c2d8a27..0000000
--- a/test/expect/TestJit.test_constant_prop_rand.expect
+++ /dev/null
@@ -1,11 +0,0 @@
-graph() {
-  %0 : int = prim::Constant[value=1]()
-  %1 : Device = prim::Constant[value="cpu"]()
-  %2 : int = prim::Constant[value=0]()
-  %3 : int = prim::Constant[value=6]()
-  %4 : int = prim::Constant[value=2]()
-  %5 : int[] = prim::Constant[value=[3]]()
-  %a : Tensor = aten::randn(%5, %3, %2, %1)
-  %b : Tensor = aten::add(%a, %4, %0)
-  return (%b);
-}
diff --git a/test/expect/TestJit.test_constant_prop_simple.expect b/test/expect/TestJit.test_constant_prop_simple.expect
deleted file mode 100644
index bb4326e..0000000
--- a/test/expect/TestJit.test_constant_prop_simple.expect
+++ /dev/null
@@ -1,6 +0,0 @@
-graph(%input_tensor : Tensor) {
-  %1 : int = prim::Constant[value=1]()
-  %2 : int = prim::Constant[value=8]()
-  %3 : Tensor = aten::add(%input_tensor, %2, %1)
-  return (%3);
-}
diff --git a/test/test_jit.py b/test/test_jit.py
index 95e09ab..33bc980 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1666,17 +1666,20 @@ class TestJit(JitTestCase):
 
     def test_constant_prop_simple(self):
         @torch.jit.script
-        def constant_prop(input_tensor):
+        def constant_prop(input_int):
+            # type: (int) -> int
             a = 2 * 3
             b = a + 2
-            return b + input_tensor
+            return b - input_int
 
-        x = torch.tensor(2)
-        out_ref = constant_prop(x)
+        out_ref = constant_prop(2)
         self.run_pass('constant_propagation', constant_prop.graph)
-        out_test = constant_prop(torch.tensor(2))
+        out_test = constant_prop(2)
         self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph_str = str(constant_prop.graph)
+        self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
+        const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
+        self.assertEqual(const, 8)
 
     def test_constant_prop_nested(self):
         @torch.jit.script
@@ -1691,7 +1694,10 @@ class TestJit(JitTestCase):
         self.run_pass('constant_propagation', constant_prop.graph)
         out_test = constant_prop(torch.tensor(2))
         self.assertEqual(out_ref, out_test)
-        self.assertExpected(canonical(constant_prop.graph))
+        if_node = constant_prop.graph.findNode("prim::If")
+        for block in if_node.blocks():
+            for node in block.nodes():
+                self.assertTrue(node.kind() == "prim::Constant")
 
     def test_constant_prop_print(self):
         @torch.jit.script
         def constant_prop(input_tensor):
             a = 2 * 3
             print(a)
             b = a + 2
             return b + input_tensor
 
         self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph = constant_prop.graph
+        print_node = graph.findNode("prim::Print")
+        self.assertTrue(print_node.input().toIValue() == 6)
 
     def test_constant_prop_rand(self):
         @torch.jit.script
         def constant_prop():
             a = torch.randn([3])
             b = a + 2
             return b
 
         self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        self.assertTrue("aten::randn" in str(constant_prop.graph))
 
     def test_constant_prop_none(self):
         @torch.jit.script
@@ -1783,8 +1791,16 @@ class TestJit(JitTestCase):
             c2 = c2 + 4 # set to 5
             return a + c0 + c1 + c2
 
-        self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph = constant_prop.graph
+        self.run_pass('constant_propagation', graph)
+        ifs = graph.findAllNodes("prim::If", recurse=False)
+        snd_if_inlined = len(ifs) == 1
+        self.assertTrue(snd_if_inlined)
+        first_if = ifs[0]
+        self.assertTrue(first_if.outputsSize() == 2)
+        second_if = first_if.findNode("prim::If", recurse=False)
+        self.assertTrue(second_if.outputsSize() == 1)
+        self.assertTrue(second_if.findNode("prim::If") is None)
 
     def test_constant_prop_loop_constant(self):
         @torch.jit.script
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 3e73bde..1c88709 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -69,35 +69,55 @@ std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
   }
 }
 
-std::vector<Node*> findAllNodes(Block* block, Symbol kind) {
+std::vector<Node*> findAllNodes(
+    c10::ArrayRef<Block*> blocks,
+    Symbol kind,
+    bool recurse = true) {
   std::vector<Node*> ret;
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      auto nodes = findAllNodes(b, kind);
-      ret.insert(ret.end(), nodes.begin(), nodes.end());
-    }
-    if (n->kind() == kind) {
-      ret.push_back(n);
+  for (Block* block : blocks) {
+    for (Node* n : block->nodes()) {
+      if (n->kind() == kind) {
+        ret.push_back(n);
+      }
+      if (recurse) {
+        auto nodes = findAllNodes(n->blocks(), kind, recurse);
+        ret.insert(ret.end(), nodes.begin(), nodes.end());
+      }
     }
   }
   return ret;
 }
 
-Node* findNode(Block* block, Symbol kind) {
-  for (Node* n : block->nodes()) {
-    for (Block* b : n->blocks()) {
-      auto node = findNode(b, kind);
-      if (node != nullptr) {
-        return node;
+std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) {
+  std::vector<Block*> blocks = {block};
+  return findAllNodes(blocks, kind, recurse);
+}
+
+Node* findNode(
+    c10::ArrayRef<Block*> blocks,
+    Symbol kind,
+    bool recurse = true) {
+  for (Block* block : blocks) {
+    for (Node* n : block->nodes()) {
+      if (n->kind() == kind) {
+        return n;
+      }
+      if (recurse) {
+        auto node = findNode(n->blocks(), kind, recurse);
+        if (node != nullptr) {
+          return node;
+        }
       }
-    }
-    if (n->kind() == kind) {
-      return n;
     }
   }
   return nullptr;
 }
 
+Node* findNode(Block* block, Symbol kind, bool recurse = true) {
+  std::vector<Block*> blocks = {block};
+  return findNode(blocks, kind, recurse);
+}
+
 // execute a Python function, used for Ops we can't optimize but that we want to
 // optimize around
 struct ConcretePythonOp : public PythonOp {
@@ -269,14 +289,15 @@ void initPythonIRBindings(PyObject* module_) {
       })
       .def(
           "findNode",
-          [](Graph& g, const std::string& kind) {
-            return findNode(g.block(), Symbol::fromQualString(kind));
-          })
+          [](Graph& g, const std::string& kind, bool recurse) {
+            return findNode(g.block(), Symbol::fromQualString(kind), recurse);
+          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
      .def(
          "findAllNodes",
-          [](Graph& g, const std::string& kind) {
-            return findAllNodes(g.block(), Symbol::fromQualString(kind));
-          })
+          [](Graph& g, const std::string& kind, bool recurse) {
+            return findAllNodes(
+                g.block(), Symbol::fromQualString(kind), recurse);
+          }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
       .def("addInput", [](Graph& g) { return g.addInput(); })
       .def("copy", [](Graph& g) { return g.copy(); })
       .GS(eraseInput)
@@ -361,7 +382,18 @@ void initPythonIRBindings(PyObject* module_) {
   py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
       .def("nodes", [](Block& b) {
        return py::make_iterator(b.nodes().begin(), b.nodes().end());
-      });
+      })
+      .def(
+          "findNode",
+          [](Block& b, const std::string& kind, bool recurse) {
+            return findNode(&b, Symbol::fromQualString(kind), recurse);
+          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+      .def(
+          "findAllNodes",
+          [](Block& b, const std::string& kind, bool recurse) {
+            return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
+          }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true);
+
 #define NS(name) def(#name, &Node ::name)
   py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
@@ -400,26 +432,16 @@ void initPythonIRBindings(PyObject* module_) {
       .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
       .def(
          "findNode",
-          [](Node& n, const std::string& kind) {
-            Node* node;
-            for (Block* b : n.blocks()) {
-              node = findNode(b, Symbol::fromQualString(kind));
-              if (node != nullptr) {
-                return node;
-              }
-            }
-            return node;
-          })
+          [](Node& n, const std::string& kind, bool recurse) {
+            return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
+          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
      .def(
          "findAllNodes",
-          [](Node& n, const std::string& kind) {
-            std::vector<Node*> ret;
-            for (Block* b : n.blocks()) {
-              auto nodes = findAllNodes(b, Symbol::fromQualString(kind));
-              ret.insert(ret.end(), nodes.begin(), nodes.end());
-            }
-            return ret;
-          })
+          [](Node& n, const std::string& kind, bool recurse) {
+            return findAllNodes(
+                n.blocks(), Symbol::fromQualString(kind), recurse);
+          }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
+      .def("input", [](Node& n) { return n.input(); })
       .def("output", [](Node& n) { return n.output(); })
       .NS(addInput)
       .NS(replaceInput)
--
2.7.4
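
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the
findNode/findAllNodes bindings added above can express graph conditions directly from
Python, mirroring the updated test_constant_prop_simple test. Calling
torch._C._jit_pass_constant_propagation directly is an assumption standing in for
JitTestCase.run_pass('constant_propagation', ...), which dispatches to that pass.

    import torch

    @torch.jit.script
    def constant_prop(input_int):
        # type: (int) -> int
        a = 2 * 3
        b = a + 2
        return b - input_int

    graph = constant_prop.graph
    # Fold 2 * 3 + 2 into a single prim::Constant.
    torch._C._jit_pass_constant_propagation(graph)

    # Graph.findNode returns the first matching node; nested blocks are searched
    # because recurse defaults to True.
    const = graph.findNode("prim::Constant").output().toIValue()
    assert const == 8

    # recurse=False restricts the search to the top-level block; this graph has no
    # control flow, so no prim::If nodes are found.
    assert len(graph.findAllNodes("prim::If", recurse=False)) == 0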