Constant propagation changes (#16244)
authorElias Ellison <eellison@fb.com>
Thu, 24 Jan 2019 01:47:29 +0000 (17:47 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 24 Jan 2019 01:50:33 +0000 (17:50 -0800)
Summary:
- remove loop node that is guaranteed not to execute
- remove extra loop outputs that are no longer needed

- if we are inlining an if node, only run constant propagation on the block that will execute

- remove the recurse argument since we only expose the Graph Constant Propagation and it's not used

This also includes  a few extra hooks to python_ir that I think make it a little be easier to test graph conditions from python.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16244

Differential Revision: D13791635

Pulled By: eellison

fbshipit-source-id: d16351fffcfc8013b02015db200f8fde002e0577

test/test_jit.py
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/python_ir.cpp

index b26521f..63df8cf 100644 (file)
@@ -1719,6 +1719,20 @@ class TestJit(JitTestCase):
         graph_str = str(constant_prop.graph)
         self.assertTrue(graph_str.count("prim::None") == 0)
 
+    def test_constant_prop_if_inline(self):
+        @torch.jit.script
+        def constant_prop():
+            cond = True
+            a = 1
+            if cond:
+                a = 1 * 2
+            else:
+                a = 1 // 0
+            return a
+
+        # testing that 1 // 0 error is not thrownn
+        self.run_pass('constant_propagation', constant_prop.graph)
+
     def test_trace_records_names(self):
         def foo(bar, baz):
             baz = bar + 3
@@ -1759,16 +1773,49 @@ class TestJit(JitTestCase):
 
     def test_constant_prop_loop_constant(self):
         @torch.jit.script
-        def constant_prop():
+        def constant_prop(cond, iter):
+            # type: (bool, int) -> int
             b = 0
             while True:
-                b = 1
+                print("stays")
+            for _ in range(2):
+                print("stays")
+            for _ in range(iter):
+                print("stays")
+            while cond:
+                print("stays")
             while False:
-                b = 2
+                print("removed")
+            for _i in range(0):
+                print("removed")
+            for _i in range(-4):
+                print("removed")
             return b
 
         self.run_pass('constant_propagation', constant_prop.graph)
-        self.assertExpected(canonical(constant_prop.graph))
+        graph = canonical(constant_prop.graph)
+        self.assertTrue(graph.count("removed") == 0)
+        self.assertTrue(graph.count("stays") == 1)  # constant gets pooled
+        self.assertTrue(graph.count("prim::Print") == 4)
+
+    def test_constant_prop_remove_output(self):
+        @torch.jit.script
+        def constant_prop(iter):
+            # type: (int) -> None
+            a = 1
+            b = 1
+            c = 1
+            for i in range(iter):
+                if False:
+                    a = 10
+                if i == 5:
+                    b = 2
+                    c = 3
+            print(a, b, c)
+
+        graph = constant_prop.graph
+        self.run_pass('constant_propagation', graph)
+        self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
 
     def test_trace_detach(self):
         def foo(x, w):
index 3951eb5..f580a94 100644 (file)
@@ -1,3 +1,4 @@
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/interpreter.h>
@@ -5,7 +6,6 @@
 #include <torch/csrc/jit/ivalue.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
-#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/utils/functional.h>
 
@@ -16,10 +16,10 @@ namespace {
 
 std::unordered_set<Symbol> skip_list = {
     prim::If,
-    prim::Loop, // TODO: handle Loop
+    prim::Loop,
     prim::Constant,
     prim::Undefined,
-    prim::unchecked_unwrap_optional, //TODO remove
+    prim::unchecked_unwrap_optional, // TODO remove
     prim::None, // it is already a constant and propagating it will lose
                 // important type information about which Optional type it is
     // TODO (zach): we should consider skipping tensor factories in the cases
@@ -75,7 +75,30 @@ void propagateNode(Node* n) {
   }
 }
 
-void inlineIf(Block* body, Node* n) {
+void removeLoopNode(Node* n) {
+  auto loop_input_offset = 2; // offset of loop carried deps in input list
+  for (size_t i = 0; i < n->outputs().size(); ++i) {
+    n->outputs().at(i)->replaceAllUsesWith(
+        n->inputs().at(i + loop_input_offset));
+  }
+  n->destroy();
+}
+
+bool loopWillNotRun(Node* node) {
+  Value* trip_count = node->inputs().at(0);
+  int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1);
+
+  Value* start_cond = node->inputs().at(1);
+  bool cond_val = constant_as<bool>(start_cond).value_or(true);
+
+  bool loop_might_run = cond_val && iter_len > 0;
+  return !loop_might_run;
+}
+
+void ConstantPropagation(Block* block, const AliasDb& aliasDb);
+
+void inlineIfBody(Block* body) {
+  Node* n = body->owningNode();
   for (auto it = body->nodes().begin(); it != body->nodes().end();) {
     Node* body_node = *it;
     // advance iterator because after body_node is moved its next pointer will
@@ -91,22 +114,16 @@ void inlineIf(Block* body, Node* n) {
   n->destroy();
 }
 
-bool isTrueConstant(Value* val) {
-  c10::optional<bool> maybe_value = constant_as<bool>(val);
-  JIT_ASSERT(maybe_value);
-  return *maybe_value;
-}
-
-void inlineIf(Node* n) {
-  if (isTrueConstant(n->input())) {
-    inlineIf(n->blocks()[0], n);
-  } else {
-    inlineIf(n->blocks()[1], n);
-  }
+void inlineIf(Node* n, const AliasDb& aliasDb) {
+  auto input_bool = constant_as<bool>(n->input());
+  JIT_ASSERT(input_bool);
+  size_t block_index = *input_bool ? 0 : 1;
+  ConstantPropagation(n->blocks().at(block_index), aliasDb);
+  inlineIfBody(n->blocks().at(block_index));
 }
 
 // remove extra outputs from the node
-bool removeExtraNodeOutputs(Node* n) {
+bool removeExtraIfOutputs(Node* n) {
   JIT_ASSERTM(n->kind() == prim::If, "Only supported for If nodes");
   auto true_block = n->blocks()[0];
   auto false_block = n->blocks()[1];
@@ -126,53 +143,78 @@ bool removeExtraNodeOutputs(Node* n) {
   return initial_outputs != true_block->outputs().size();
 }
 
-void ConstantPropagation(Block* block, const AliasDb& aliasDb, bool recurse);
+// remove extra outputs from the node
+void removeExtraLoopOutputs(Node* node) {
+  auto loop_body = node->blocks().at(0);
+  auto loop_input_offset = 2; // offset of loop carried deps in input list
+  auto loop_body_offset =
+      1; // offset to the loop carried dependencies in block inputs/outputs
+  for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
+    size_t i = i_1 - 1;
+    // if the value is no longer changed remove output
+    if (loop_body->inputs().at(loop_body_offset + i) ==
+        loop_body->outputs().at(loop_body_offset + i)) {
+      auto node_input = node->inputs().at(loop_input_offset + i);
+      node->outputs().at(i)->replaceAllUsesWith(node_input);
+      loop_body->inputs()
+          .at(loop_body_offset + i)
+          ->replaceAllUsesWith(node_input);
+      node->eraseOutput(i);
+      node->removeInput(loop_input_offset + i);
+      loop_body->eraseInput(loop_body_offset + i);
+      loop_body->eraseOutput(loop_body_offset + i);
+    }
+  }
+}
 
-void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
+void ConstantPropagation(Node* n, const AliasDb& aliasDb) {
   bool constant_inputs =
       std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
-        return v->node()->kind() == prim::Constant || v->node()->kind() == prim::None;
+        return v->node()->kind() == prim::Constant ||
+            v->node()->kind() == prim::None;
       });
   bool supported_node = !n->kind().is_onnx() &&
       skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
       !n->hasSideEffects() && !aliasDb.hasWriters(n);
   auto run_blocks = [&]() {
-    if (recurse) {
-      for (Block* block : n->blocks()) {
-        ConstantPropagation(block, aliasDb, recurse);
-      }
+    for (Block* block : n->blocks()) {
+      ConstantPropagation(block, aliasDb);
     }
   };
   if (n->kind() == prim::If) {
-    run_blocks();
     // inline node if we can, otherwise check for simplified outputs
     if (constant_inputs) {
-      inlineIf(n);
+      inlineIf(n, aliasDb);
     } else {
-      removeExtraNodeOutputs(n);
+      run_blocks();
+      removeExtraIfOutputs(n);
+    }
+  } else if (n->kind() == prim::Loop) {
+    if (loopWillNotRun(n)) {
+      removeLoopNode(n);
+    } else {
+      run_blocks();
+      removeExtraLoopOutputs(n);
     }
-    // don't rerun run_blocks
-    return;
   } else if (constant_inputs && supported_node) {
     propagateNode(n);
+  } else {
+    run_blocks();
   }
-  // TODO handle loop nodes. Even if a loop node contains an if that is
-  // inlined its mutated variables currently don't get updated
-  run_blocks();
 }
 
-void ConstantPropagation(Block* block, const AliasDb& aliasDb, bool recurse) {
+void ConstantPropagation(Block* block, const AliasDb& aliasDb) {
   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
     Node* n = *it;
     it++; // advance iterator bc the current node may be destroyed
-    ConstantPropagation(n, aliasDb, recurse);
+    ConstantPropagation(n, aliasDb);
   }
 }
 } // anonymous namespace
 
 void ConstantPropagation(std::shared_ptr<Graph>& graph) {
   const auto aliasDb = AliasAnalysis(graph);
-  ConstantPropagation(graph->block(), aliasDb, true);
+  ConstantPropagation(graph->block(), aliasDb);
   EliminateDeadCode(graph);
 }
 
index e79ff82..bb628a6 100644 (file)
@@ -68,6 +68,35 @@ std::ostream& printPyObject(std::ostream& out, const THPObjectPtr& obj) {
   }
 }
 
+std::vector<Node*> findAllNodes(Block* block, Symbol kind) {
+  std::vector<Node*> ret;
+  for (Node* n : block->nodes()) {
+    for (Block* b : n->blocks()) {
+      auto nodes = findAllNodes(b, kind);
+      ret.insert(ret.end(), nodes.begin(), nodes.end());
+    }
+    if (n->kind() == kind) {
+      ret.push_back(n);
+    }
+  }
+  return ret;
+}
+
+Node* findNode(Block* block, Symbol kind) {
+  for (Node* n : block->nodes()) {
+    for (Block* b : n->blocks()) {
+      auto node = findNode(b, kind);
+      if (node != nullptr) {
+        return node;
+      }
+    }
+    if (n->kind() == kind) {
+      return n;
+    }
+  }
+  return nullptr;
+}
+
 // execute a Python function, used for Ops we can't optimize but that we want to
 // optimize around
 struct ConcretePythonOp : public PythonOp {
@@ -231,6 +260,16 @@ void initPythonIRBindings(PyObject* module_) {
           [](Graph& g) {
             return py::make_iterator(g.nodes().begin(), g.nodes().end());
           })
+      .def(
+          "findNode",
+          [](Graph& g, const std::string& kind) {
+            return findNode(g.block(), Symbol::fromQualString(kind));
+          })
+      .def(
+          "findAllNodes",
+          [](Graph& g, const std::string& kind) {
+            return findAllNodes(g.block(), Symbol::fromQualString(kind));
+          })
       .def("addInput", [](Graph& g) { return g.addInput(); })
       .def("copy", [](Graph& g) { return g.copy(); })
       .GS(eraseInput)
@@ -308,8 +347,8 @@ void initPythonIRBindings(PyObject* module_) {
             return node;
           })
       .VS(copyMetadata)
-      .VS(isTensor);
-
+      .VS(isTensor)
+      .def("toIValue", [](Value& n) { return toIValue(&n); });
 #undef VS
 
   py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
@@ -340,6 +379,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def("hasMultipleOutputs", [](Node& n) { return n.outputs().size() > 1; })
       .def("outputsSize", [](Node& n) { return n.outputs().size(); })
       .NS(kind)
+      .def("inputsAt", [](Node& n, size_t i) { return n.inputs().at(i); })
       .def(
           "inputs",
           [](Node& n) {
@@ -350,6 +390,29 @@ void initPythonIRBindings(PyObject* module_) {
           [](Node& n) {
             return py::make_iterator(n.outputs().begin(), n.outputs().end());
           })
+      .def("outputsAt", [](Node& n, size_t i) { return n.outputs().at(i); })
+      .def(
+          "findNode",
+          [](Node& n, const std::string& kind) {
+            Node* node;
+            for (Block* b : n.blocks()) {
+              node = findNode(b, Symbol::fromQualString(kind));
+              if (node != nullptr) {
+                return node;
+              }
+            }
+            return node;
+          })
+      .def(
+          "findAllNodes",
+          [](Node& n, const std::string& kind) {
+            std::vector<Node*> ret;
+            for (Block* b : n.blocks()) {
+              auto nodes = findAllNodes(b, Symbol::fromQualString(kind));
+              ret.insert(ret.end(), nodes.begin(), nodes.end());
+            }
+            return ret;
+          })
       .def("output", [](Node& n) { return n.output(); })
       .NS(addInput)
       .NS(replaceInput)
@@ -542,8 +605,7 @@ void initPythonIRBindings(PyObject* module_) {
         return types;
       });
   py::class_<ListType, Type, std::shared_ptr<ListType>>(m, "ListType")
-      .def(
-          py::init([](TypePtr a) { return ListType::create(a); }))
+      .def(py::init([](TypePtr a) { return ListType::create(a); }))
       .def_static("ofInts", &ListType::ofInts)
       .def_static("ofTensors", &ListType::ofTensors)
       .def("getElementType", &ListType::getElementType);