Optimize boolean expressions & unwraps (#18259)
authoreellison <elias_ellison@brown.edu>
Tue, 26 Mar 2019 04:48:11 +0000 (21:48 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 04:50:57 +0000 (21:50 -0700)
Summary:
Simplify or eliminate boolean and/or expressions, optimize unwrapping a value that cannot be None, and optimize using `is` with a None and a non-None value

Since peephole optimize is now introducing constants, i added another constant propagation pass after running it.

Previously i had a PR that did this & optimized shape ops - i will add the shape optimizations in a separate PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18259

Differential Revision: D14602749

Pulled By: eellison

fbshipit-source-id: 1c3f5a67067d8dfdf55d7b78dcb616472ea8a267

test/cpp/jit/test.cpp
test/cpp/jit/test_peephole_optimize.h [new file with mode: 0644]
test/test_jit.py
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/peephole.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/schema_type_parser.cpp

index 1abb5ed..1c4823d 100644 (file)
@@ -24,6 +24,7 @@
 #include <test/cpp/jit/test_ivalue.h>
 #include <test/cpp/jit/test_misc.h>
 #include <test/cpp/jit/test_netdef_converter.h>
+#include <test/cpp/jit/test_peephole_optimize.h>
 #include <test/cpp/jit/test_subgraph_utils.h>
 
 using namespace torch::jit::script;
@@ -61,7 +62,8 @@ namespace jit {
   _(THNNConv)                      \
   _(ATenNativeBatchNorm)           \
   _(NoneSchemaMatch)               \
-  _(ClassParser)
+  _(ClassParser)                   \
+  _(PeepholeOptimize)
 
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec)               \
diff --git a/test/cpp/jit/test_peephole_optimize.h b/test/cpp/jit/test_peephole_optimize.h
new file mode 100644 (file)
index 0000000..32aacf4
--- /dev/null
@@ -0,0 +1,104 @@
+#pragma once
+
+#include <test/cpp/jit/test_base.h>
+#include <test/cpp/jit/test_utils.h>
+
+#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/passes/peephole.h>
+
+namespace torch {
+namespace jit {
+
+using namespace script;
+using namespace testing;
+
+namespace test {
+
+void testPeepholeOptimize() {
+  // test is / is not none optimization
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%0 : int):
+  %1 : None = prim::Constant()
+  %2 : bool = aten::__is__(%0, %1)
+  %3 : bool = aten::__isnot__(%0, %1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck()
+        .check_not("aten::__is__")
+        ->check_not("aten::__isnot__")
+        ->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%0: int?):
+  %1 : None = prim::Constant()
+  %2 : bool = aten::__is__(%0, %1)
+  %3 : bool = aten::__isnot__(%0, %1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck()
+        .check("aten::__is__")
+        ->check("aten::__isnot__")
+        ->run(*graph);
+  }
+
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%0: int?):
+  %1 : Tensor = prim::AutogradZero()
+  %2 : None = prim::Constant()
+  %4 : bool = aten::__is__(%0, %1)
+  %5 : bool = aten::__isnot__(%1, %2)
+  return (%4, %5)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck()
+        .check("aten::__is__")
+        ->check_not("aten::__isnot__")
+        ->run(*graph);
+  }
+
+  // test unwrap optional
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph():
+  %1 : Float(*, *, *) = prim::Constant()
+  %2 : bool = aten::_unwrap_optional(%1)
+  %3 : bool = prim::unchecked_unwrap_optional(%1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck().check_not("unwrap")->run(*graph);
+  }
+  {
+    auto graph = std::make_shared<Graph>();
+    parseIR(
+        R"IR(
+graph(%1 : Float(*, *, *)?):
+  %2 : bool = aten::_unwrap_optional(%1)
+  %3 : bool = prim::unchecked_unwrap_optional(%1)
+  return (%2, %3)
+  )IR",
+        graph.get());
+    PeepholeOptimize(graph);
+    testing::FileCheck().check_count("unwrap", 2)->run(*graph);
+  }
+}
+} // namespace test
+} // namespace jit
+} // namespace torch
index c2318e1..6ef2c1f 100644 (file)
@@ -1881,6 +1881,26 @@ class TestJit(JitTestCase):
         # testing that 1 // 0 error is not thrownn
         self.run_pass('constant_propagation', constant_prop.graph)
 
+    def test_short_circuit_optimization(self):
+        @torch.jit.script
+        def const_expressions(x):
+            # type: (int) -> Tuple[bool, bool]
+            return x == 1 and False, x == 1 or True
+        self.run_pass('constant_propagation', const_expressions.graph)
+        FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
+        self.assertEqual(const_expressions(1), (False, True))
+
+        @torch.jit.script
+        def redundant_expressions(x):
+            # type: (int) -> Tuple[bool, bool]
+            return x == 1 and True, x == 1 or False
+
+        self.run_pass('peephole', redundant_expressions.graph)
+        self.assertEqual(redundant_expressions(1), (True, True))
+        self.assertEqual(redundant_expressions(0), (False, False))
+        # and True / or False are removed from graph
+        FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
+
     def test_trace_records_names(self):
         def foo(bar, baz):
             baz = bar + 3
index f7cd332..57768c0 100644 (file)
@@ -292,7 +292,6 @@ Gradient getGradient(const Node* n) {
   grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
   return grad;
 }
-
 } // anonymous namespace
 
 RegisterOperators reg_graph_executor_ops(
@@ -308,7 +307,6 @@ GraphExecutor* getGradExecutor(Operation& op) {
   }
   return nullptr;
 }
-
 } // namespace detail
 
 // a Graph can be created via tracing, or via a language-based frontend
@@ -505,6 +503,7 @@ struct GraphExecutorImpl {
     ConstantPooling(graph);
 
     PeepholeOptimize(graph);
+    ConstantPropagation(graph);
 
     // Unroll small loops, and eliminate expressions that are the same at every
     // iteration.
@@ -644,6 +643,5 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
   CanonicalizeOps(g);
   EliminateDeadCode(g);
 }
-
 } // namespace jit
 } // namespace torch
index b168b9d..3a8e292 100644 (file)
@@ -641,6 +641,10 @@ std::shared_ptr<Graph> Graph::copy() {
 bool Value::mustBeNone() const {
   return node_->mustBeNone();
 }
+bool Value::mustNotBeNone() const {
+  return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
+      !type()->cast<OptionalType>();
+}
 
 std::string Value::uniqueNameBase() const {
   std::string name = uniqueName();
@@ -771,9 +775,10 @@ bool Node::matches(
 }
 
 bool Node::mustBeNone() const {
-  return kind_ == prim::Constant && !this->hasAttributes() &&
-      (output()->type()->cast<OptionalType>() ||
-       output()->type() == NoneType::get());
+  return kind_ == prim::AutogradZero ||
+      (kind_ == prim::Constant && !this->hasAttributes() &&
+       (output()->type()->cast<OptionalType>() ||
+        output()->type() == NoneType::get()));
 }
 
 void Node::dump() const {
index 4ef6883..afcadb6 100644 (file)
@@ -171,6 +171,7 @@ struct Value {
     return type()->kind() == TypeKind::CompleteTensorType;
   }
   TORCH_API bool mustBeNone() const;
+  TORCH_API bool mustNotBeNone() const;
   size_t unique() const {
     return unique_;
   }
index 5b7652b..ffe6031 100644 (file)
@@ -5,6 +5,7 @@
 #include <torch/csrc/jit/constants.h>
 #include <torch/csrc/jit/interpreter.h>
 #include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/node_hashing.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/passes/alias_analysis.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
@@ -119,22 +120,41 @@ void inlineIf(Node* n, const AliasDb& aliasDb) {
   inlineIfBody(n->blocks().at(block_index));
 }
 
+void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
+  n->outputs().at(i)->replaceAllUsesWith(replacement);
+  n->eraseOutput(i);
+  n->blocks().at(0)->eraseOutput(i);
+  n->blocks().at(1)->eraseOutput(i);
+}
+
 // remove extra outputs from the node
 bool removeExtraIfOutputs(Node* n) {
   AT_CHECK(n->kind() == prim::If, "Only supported for If nodes");
   auto true_block = n->blocks()[0];
   auto false_block = n->blocks()[1];
+  auto graph = n->owningGraph();
   auto initial_outputs = true_block->outputs().size();
+  WithInsertPoint guard(n);
   for (size_t i = 0; i < true_block->outputs().size();) {
+    auto t_out = true_block->outputs().at(i);
+    auto f_out = false_block->outputs().at(i);
+
     // neither block changes the output value
     if (true_block->outputs()[i] == false_block->outputs()[i]) {
-      n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
-      n->eraseOutput(i);
-      true_block->eraseOutput(i);
-      false_block->eraseOutput(i);
-    } else {
-      i++; // increment bc we didn't remove current index
+      replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
+      continue;
+    }
+
+    // true block output is constant and constant matches false block output
+    auto maybe_const = toIValue(t_out);
+    auto eq = EqualNode();
+    if (maybe_const && eq(t_out->node(), f_out->node())) {
+      auto new_const = graph->insertConstant(*maybe_const, t_out->type());
+      replaceAndRemoveIfOutput(n, i, new_const);
+      continue;
     }
+
+    i++; // increment bc we didn't remove current index
   }
   // an output was removed
   return initial_outputs != true_block->outputs().size();
@@ -213,6 +233,5 @@ void ConstantPropagation(std::shared_ptr<Graph>& graph) {
   ConstantPropagation(graph->block(), aliasDb);
   EliminateDeadCode(graph);
 }
-
 } // namespace jit
 } // namespace torch
index 02b3993..5a55b97 100644 (file)
@@ -1,5 +1,5 @@
 #include <torch/csrc/jit/passes/peephole.h>
-
+#include <torch/csrc/jit/ir_views.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
@@ -165,6 +165,49 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
           u.user->replaceInput(0, node->inputs().at(0));
         }
       }
+    } else if (node->kind() == prim::If) {
+      IfView n(node);
+      // this handles redundant short circuits like "x and True" or "x or False"
+      for (size_t i = 0; i < n.outputs().size(); ++i) {
+        if (n.outputs().at(i)->type() != BoolType::get()) {
+          continue;
+        }
+        bool true_val =
+            constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
+        bool false_val =
+            constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
+        // if an if node's output equals its condition replace output with
+        // condition
+        if (true_val && !false_val) {
+          n.outputs().at(i)->replaceAllUsesWith(n.cond());
+        }
+      }
+    } else if (
+        node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
+      // if we are comparing a None value with a value that can't be None
+      // replace the output with true if node is __isnot__ or false if node is
+      // __is__
+      AT_ASSERT(node->inputs().size() == 2);
+      for (size_t check_none_index : {0, 1}) {
+        bool input_must_be_none =
+            node->inputs().at(check_none_index)->mustBeNone();
+        bool other_must_not_be_none =
+            node->inputs().at(1 - check_none_index)->mustNotBeNone();
+        if (input_must_be_none && other_must_not_be_none) {
+          WithInsertPoint guard(node);
+          auto output = node->owningGraph()->insertConstant(
+              node->kind() == aten::__isnot__);
+          node->output()->replaceAllUsesWith(output);
+        }
+      }
+    } else if (
+        node->kind() == prim::unchecked_unwrap_optional ||
+        node->kind() == aten::_unwrap_optional) {
+      // we are unwrapping an input that can't be None, remove the unwrap
+      auto input = node->input();
+      if (input->mustNotBeNone()) {
+        node->output()->replaceAllUsesWith(node->input());
+      }
     }
   }
 }
@@ -180,6 +223,5 @@ void PeepholeOptimize(
     bool addmm_fusion_enabled) {
   PeepholeOptimize(graph->block(), addmm_fusion_enabled);
 }
-
 } // namespace jit
 } // namespace torch
index c4796f5..1a06161 100644 (file)
@@ -1039,21 +1039,32 @@ struct to_ir {
     const auto first_bool_info = findRefinements(first_expr);
     Value* first_value = emitCond(Expr(first_expr));
 
+    // if the second expr in the short circuit is not evaluated,
+    // than the first expression is False if the short circuit
+    // is an `and` and True if the short circuit is an `or`.
+    // `False and expr` -> False, `True or expr` -> True
+    //
+    // inserting it as a constant makes optimization easier
+
+    Value* first_value_returned;
+
     const Refinements* first_expr_refinements;
     const Refinements* second_expr_refinements;
     // if it's an OR the first expr is emitted in the true branch
     // and the second expr in the false branch, if it's an AND the opposite
     if (is_or) {
+      first_value_returned = graph->insertConstant(true, nullptr, loc);
       first_expr_refinements = &first_bool_info.true_refinements_;
       second_expr_refinements = &first_bool_info.false_refinements_;
     } else {
+      first_value_returned = graph->insertConstant(false, nullptr, loc);
       first_expr_refinements = &first_bool_info.false_refinements_;
       second_expr_refinements = &first_bool_info.true_refinements_;
     }
 
     auto get_first_expr = [&] {
       insertRefinements(*first_expr_refinements);
-      return first_value;
+      return first_value_returned;
     };
 
     auto get_second_expr = [&] {
@@ -2094,7 +2105,6 @@ struct to_ir {
       }
       return classNew->createObject(
           apply.range(), method, Var(apply.inputs()[0]).name().name());
-      ;
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
index 19e9ccb..14449e8 100644 (file)
@@ -21,9 +21,14 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
       {"float", FloatType::get()},
       {"int", IntType::get()},
       {"bool", BoolType::get()},
+      {"None", NoneType::get()},
   };
-  auto tok = L.expect(TK_IDENT);
-  auto text = tok.text();
+  auto tok = L.cur();
+  if (!L.nextIf(TK_NONE)) {
+    L.expect(TK_IDENT);
+  }
+  std::string text = tok.text();
+
   auto it = type_map.find(text);
   if (it == type_map.end()) {
     if (text.size() > 0 && islower(text[0])) {