From dc6b5b2a52d182019f6f3f5f21b58e2fb592d993 Mon Sep 17 00:00:00 2001
From: eellison
Date: Mon, 25 Mar 2019 21:48:11 -0700
Subject: [PATCH] Optimize boolean expressions & unwraps (#18259)

Summary:
Simplify or eliminate boolean and/or expressions, optimize unwrapping a
value that cannot be None, and optimize using `is` with a None and a
non-None value.

Since the peephole pass now introduces constants, I added another
constant-propagation pass after running it.

Previously I had a PR that did this & optimized shape ops; I will add
the shape optimizations in a separate PR.
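For a quick illustration of the user-visible effect, here is a runnable
sketch distilled from the new `test_short_circuit_optimization` cases
(the `print` driver at the bottom is illustrative only, not part of the
patch):

```python
from typing import Tuple

import torch

@torch.jit.script
def redundant_expressions(x):
    # type: (int) -> Tuple[bool, bool]
    # `and True` / `or False` are no-ops; the peephole pass rewrites
    # both tuple elements to just `x == 1`.
    return x == 1 and True, x == 1 or False

@torch.jit.script
def const_expressions(x):
    # type: (int) -> Tuple[bool, bool]
    # `and False` / `or True` short-circuit to constants; together with
    # the extra constant-propagation pass, both elements fold to
    # constants and the graph keeps no `prim::If` or `aten::eq` at all.
    return x == 1 and False, x == 1 or True

print(redundant_expressions(1))  # (True, True)
print(const_expressions(1))      # (False, True)
```

The same machinery also folds `x is None` / `x is not None` when one
side is known to be None and the other cannot be None, and removes
`aten::_unwrap_optional` on values that cannot be None.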
+ )IR", + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck().check_not("unwrap")->run(*graph); + } + { + auto graph = std::make_shared(); + parseIR( + R"IR( +graph(%1 : Float(*, *, *)?): + %2 : bool = aten::_unwrap_optional(%1) + %3 : bool = prim::unchecked_unwrap_optional(%1) + return (%2, %3) + )IR", + graph.get()); + PeepholeOptimize(graph); + testing::FileCheck().check_count("unwrap", 2)->run(*graph); + } +} +} // namespace test +} // namespace jit +} // namespace torch diff --git a/test/test_jit.py b/test/test_jit.py index c2318e1..6ef2c1f 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1881,6 +1881,26 @@ class TestJit(JitTestCase): # testing that 1 // 0 error is not thrownn self.run_pass('constant_propagation', constant_prop.graph) + def test_short_circuit_optimization(self): + @torch.jit.script + def const_expressions(x): + # type: (int) -> Tuple[bool, bool] + return x == 1 and False, x == 1 or True + self.run_pass('constant_propagation', const_expressions.graph) + FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph) + self.assertEqual(const_expressions(1), (False, True)) + + @torch.jit.script + def redundant_expressions(x): + # type: (int) -> Tuple[bool, bool] + return x == 1 and True, x == 1 or False + + self.run_pass('peephole', redundant_expressions.graph) + self.assertEqual(redundant_expressions(1), (True, True)) + self.assertEqual(redundant_expressions(0), (False, False)) + # and True / or False are removed from graph + FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph) + def test_trace_records_names(self): def foo(bar, baz): baz = bar + 3 diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index f7cd332..57768c0 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -292,7 +292,6 @@ Gradient getGradient(const Node* n) { grad.df_output_vjps = fmap(n->is(attr::df_output_vjps)); return grad; } - } // anonymous namespace RegisterOperators reg_graph_executor_ops( @@ -308,7 +307,6 @@ GraphExecutor* getGradExecutor(Operation& op) { } return nullptr; } - } // namespace detail // a Graph can be created via tracing, or via a language-based frontend @@ -505,6 +503,7 @@ struct GraphExecutorImpl { ConstantPooling(graph); PeepholeOptimize(graph); + ConstantPropagation(graph); // Unroll small loops, and eliminate expressions that are the same at every // iteration. 
@@ -644,6 +643,5 @@ void runRequiredPasses(const std::shared_ptr<Graph>& g) {
   CanonicalizeOps(g);
   EliminateDeadCode(g);
 }
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index b168b9d..3a8e292 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -641,6 +641,10 @@ std::shared_ptr<Graph> Graph::copy() {
 bool Value::mustBeNone() const {
   return node_->mustBeNone();
 }
+bool Value::mustNotBeNone() const {
+  return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
+      !type()->cast<OptionalType>();
+}

 std::string Value::uniqueNameBase() const {
   std::string name = uniqueName();
@@ -771,9 +775,10 @@ bool Node::matches(
 }

 bool Node::mustBeNone() const {
-  return kind_ == prim::Constant && !this->hasAttributes() &&
-      (output()->type()->cast<OptionalType>() ||
-       output()->type() == NoneType::get());
+  return kind_ == prim::AutogradZero ||
+      (kind_ == prim::Constant && !this->hasAttributes() &&
+       (output()->type()->cast<OptionalType>() ||
+        output()->type() == NoneType::get()));
 }

 void Node::dump() const {
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 4ef6883..afcadb6 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -171,6 +171,7 @@ struct Value {
     return type()->kind() == TypeKind::CompleteTensorType;
   }
   TORCH_API bool mustBeNone() const;
+  TORCH_API bool mustNotBeNone() const;
   size_t unique() const {
     return unique_;
   }
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index 5b7652b..ffe6031 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/node_hashing.h>
 #include
 #include
 #include
@@ -119,22 +120,41 @@ void inlineIf(Node* n, const AliasDb& aliasDb) {
   inlineIfBody(n->blocks().at(block_index));
 }

+void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
+  n->outputs().at(i)->replaceAllUsesWith(replacement);
+  n->eraseOutput(i);
+  n->blocks().at(0)->eraseOutput(i);
+  n->blocks().at(1)->eraseOutput(i);
+}
+
 // remove extra outputs from the node
 bool removeExtraIfOutputs(Node* n) {
   AT_CHECK(n->kind() == prim::If, "Only supported for If nodes");
   auto true_block = n->blocks()[0];
   auto false_block = n->blocks()[1];
+  auto graph = n->owningGraph();
   auto initial_outputs = true_block->outputs().size();
+  WithInsertPoint guard(n);
   for (size_t i = 0; i < true_block->outputs().size();) {
+    auto t_out = true_block->outputs().at(i);
+    auto f_out = false_block->outputs().at(i);
+
+    // neither block changes the output value
     if (true_block->outputs()[i] == false_block->outputs()[i]) {
-      n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
-      n->eraseOutput(i);
-      true_block->eraseOutput(i);
-      false_block->eraseOutput(i);
-    } else {
-      i++; // increment bc we didn't remove current index
+      replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
+      continue;
+    }
+
+    // true block output is constant and constant matches false block output
+    auto maybe_const = toIValue(t_out);
+    auto eq = EqualNode();
+    if (maybe_const && eq(t_out->node(), f_out->node())) {
+      auto new_const = graph->insertConstant(*maybe_const, t_out->type());
+      replaceAndRemoveIfOutput(n, i, new_const);
+      continue;
     }
+
+    i++; // increment bc we didn't remove current index
   }
   // an output was removed
   return initial_outputs != true_block->outputs().size();
@@ -213,6 +233,5 @@ void ConstantPropagation(std::shared_ptr<Graph>& graph) {
   ConstantPropagation(graph->block(), aliasDb);
   EliminateDeadCode(graph);
 }
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 02b3993..5a55b97 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -1,5 +1,5 @@
 #include
-
+#include <torch/csrc/jit/ir_views.h>
 #include
 #include
@@ -165,6 +165,49 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
           u.user->replaceInput(0, node->inputs().at(0));
         }
       }
+    } else if (node->kind() == prim::If) {
+      IfView n(node);
+      // this handles redundant short circuits like "x and True" or "x or False"
+      for (size_t i = 0; i < n.outputs().size(); ++i) {
+        if (n.outputs().at(i)->type() != BoolType::get()) {
+          continue;
+        }
+        bool true_val =
+            constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
+        bool false_val =
+            constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
+        // if an if node's output equals its condition replace output with
+        // condition
+        if (true_val && !false_val) {
+          n.outputs().at(i)->replaceAllUsesWith(n.cond());
+        }
+      }
+    } else if (
+        node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
+      // if we are comparing a None value with a value that can't be None,
+      // replace the output with true if the node is __isnot__ or false if
+      // the node is __is__
+      AT_ASSERT(node->inputs().size() == 2);
+      for (size_t check_none_index : {0, 1}) {
+        bool input_must_be_none =
+            node->inputs().at(check_none_index)->mustBeNone();
+        bool other_must_not_be_none =
+            node->inputs().at(1 - check_none_index)->mustNotBeNone();
+        if (input_must_be_none && other_must_not_be_none) {
+          WithInsertPoint guard(node);
+          auto output = node->owningGraph()->insertConstant(
+              node->kind() == aten::__isnot__);
+          node->output()->replaceAllUsesWith(output);
+        }
+      }
+    } else if (
+        node->kind() == prim::unchecked_unwrap_optional ||
+        node->kind() == aten::_unwrap_optional) {
+      // we are unwrapping an input that can't be None, remove the unwrap
+      auto input = node->input();
+      if (input->mustNotBeNone()) {
+        node->output()->replaceAllUsesWith(node->input());
+      }
     }
   }
 }
@@ -180,6 +223,5 @@ void PeepholeOptimize(
     bool addmm_fusion_enabled) {
   PeepholeOptimize(graph->block(), addmm_fusion_enabled);
 }
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index c4796f5..1a06161 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -1039,21 +1039,32 @@ struct to_ir {
     const auto first_bool_info = findRefinements(first_expr);
     Value* first_value = emitCond(Expr(first_expr));

+    // if the second expr in the short circuit is not evaluated,
+    // then the first expression is False if the short circuit
+    // is an `and` and True if the short circuit
+    // is an `or`.
+    // `False and expr` -> False, `True or expr` -> True
+    //
+    // inserting it as a constant makes optimization easier
+
+    Value* first_value_returned;
+
     const Refinements* first_expr_refinements;
     const Refinements* second_expr_refinements;
     // if it's an OR the first expr is emitted in the true branch
     // and the second expr in the false branch, if it's an AND the opposite
     if (is_or) {
+      first_value_returned = graph->insertConstant(true, nullptr, loc);
       first_expr_refinements = &first_bool_info.true_refinements_;
       second_expr_refinements = &first_bool_info.false_refinements_;
     } else {
+      first_value_returned = graph->insertConstant(false, nullptr, loc);
       first_expr_refinements = &first_bool_info.false_refinements_;
       second_expr_refinements = &first_bool_info.true_refinements_;
     }

     auto get_first_expr = [&] {
       insertRefinements(*first_expr_refinements);
-      return first_value;
+      return first_value_returned;
     };

     auto get_second_expr = [&] {
@@ -2094,7 +2105,6 @@ struct to_ir {
       }
       return classNew->createObject(
           apply.range(), method, Var(apply.inputs()[0]).name().name());
-      ;
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp
index 19e9ccb..14449e8 100644
--- a/torch/csrc/jit/script/schema_type_parser.cpp
+++ b/torch/csrc/jit/script/schema_type_parser.cpp
@@ -21,9 +21,14 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
       {"float", FloatType::get()},
       {"int", IntType::get()},
       {"bool", BoolType::get()},
+      {"None", NoneType::get()},
   };
-  auto tok = L.expect(TK_IDENT);
-  auto text = tok.text();
+  auto tok = L.cur();
+  if (!L.nextIf(TK_NONE)) {
+    L.expect(TK_IDENT);
+  }
+  std::string text = tok.text();
+
   auto it = type_map.find(text);
   if (it == type_map.end()) {
     if (text.size() > 0 && islower(text[0])) {
-- 
2.7.4