#include <test/cpp/jit/test_ivalue.h>
#include <test/cpp/jit/test_misc.h>
#include <test/cpp/jit/test_netdef_converter.h>
+#include <test/cpp/jit/test_peephole_optimize.h>
#include <test/cpp/jit/test_subgraph_utils.h>
using namespace torch::jit::script;
_(THNNConv) \
_(ATenNativeBatchNorm) \
_(NoneSchemaMatch) \
- _(ClassParser)
+ _(ClassParser) \
+ _(PeepholeOptimize)
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
--- /dev/null
+#pragma once
+
+#include <test/cpp/jit/test_base.h>
+#include <test/cpp/jit/test_utils.h>
+
+#include <torch/csrc/jit/irparser.h>
+#include <torch/csrc/jit/passes/peephole.h>
+
+namespace torch {
+namespace jit {
+
+using namespace script;
+using namespace testing;
+
+namespace test {
+
+// Unit tests for the PeepholeOptimize pass: each case parses a small IR
+// graph with the IR parser, runs the pass, and uses FileCheck to assert
+// which ops were folded away and which must survive.
+void testPeepholeOptimize() {
+ // test is / is not none optimization
+ {
+ // %0 is a plain int and can never be None, so both comparisons against
+ // the None constant should be folded to constants and removed.
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%0 : int):
+ %1 : None = prim::Constant()
+ %2 : bool = aten::__is__(%0, %1)
+ %3 : bool = aten::__isnot__(%0, %1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck()
+ .check_not("aten::__is__")
+ ->check_not("aten::__isnot__")
+ ->run(*graph);
+ }
+ {
+ // %0 is an optional int, so neither comparison can be resolved
+ // statically; both ops must remain in the graph.
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%0: int?):
+ %1 : None = prim::Constant()
+ %2 : bool = aten::__is__(%0, %1)
+ %3 : bool = aten::__isnot__(%0, %1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck()
+ .check("aten::__is__")
+ ->check("aten::__isnot__")
+ ->run(*graph);
+ }
+
+ {
+ // Mixed case: the __isnot__ involving the None constant and the
+ // prim::AutogradZero value is expected to fold away, while the __is__
+ // on the optional input %0 cannot be resolved and must remain.
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%0: int?):
+ %1 : Tensor = prim::AutogradZero()
+ %2 : None = prim::Constant()
+ %4 : bool = aten::__is__(%0, %1)
+ %5 : bool = aten::__isnot__(%1, %2)
+ return (%4, %5)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck()
+ .check("aten::__is__")
+ ->check_not("aten::__isnot__")
+ ->run(*graph);
+ }
+
+ // test unwrap optional
+ {
+ // The input is a (non-optional) constant tensor, so both unwrap ops
+ // are redundant and should be deleted by the pass.
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph():
+ %1 : Float(*, *, *) = prim::Constant()
+ %2 : bool = aten::_unwrap_optional(%1)
+ %3 : bool = prim::unchecked_unwrap_optional(%1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck().check_not("unwrap")->run(*graph);
+ }
+ {
+ // With a genuinely optional input neither unwrap can be removed; both
+ // must still be present after the pass.
+ auto graph = std::make_shared<Graph>();
+ parseIR(
+ R"IR(
+graph(%1 : Float(*, *, *)?):
+ %2 : bool = aten::_unwrap_optional(%1)
+ %3 : bool = prim::unchecked_unwrap_optional(%1)
+ return (%2, %3)
+ )IR",
+ graph.get());
+ PeepholeOptimize(graph);
+ testing::FileCheck().check_count("unwrap", 2)->run(*graph);
+ }
+}
+} // namespace test
+} // namespace jit
+} // namespace torch
# testing that 1 // 0 error is not thrown
self.run_pass('constant_propagation', constant_prop.graph)
+ def test_short_circuit_optimization(self):
+ # `x == 1 and False` is always False and `x == 1 or True` is always
+ # True, so constant propagation should fold both the comparison and
+ # the short-circuit prim::If away entirely.
+ @torch.jit.script
+ def const_expressions(x):
+ # type: (int) -> Tuple[bool, bool]
+ return x == 1 and False, x == 1 or True
+ self.run_pass('constant_propagation', const_expressions.graph)
+ FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
+ self.assertEqual(const_expressions(1), (False, True))
+
+ # `and True` / `or False` are redundant: peephole should drop the
+ # prim::If but must keep the aten::eq the result depends on.
+ @torch.jit.script
+ def redundant_expressions(x):
+ # type: (int) -> Tuple[bool, bool]
+ return x == 1 and True, x == 1 or False
+
+ self.run_pass('peephole', redundant_expressions.graph)
+ self.assertEqual(redundant_expressions(1), (True, True))
+ self.assertEqual(redundant_expressions(0), (False, False))
+ # and True / or False are removed from graph
+ FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
+
def test_trace_records_names(self):
def foo(bar, baz):
baz = bar + 3
grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
return grad;
}
-
} // anonymous namespace
RegisterOperators reg_graph_executor_ops(
}
return nullptr;
}
-
} // namespace detail
// a Graph can be created via tracing, or via a language-based frontend
ConstantPooling(graph);
PeepholeOptimize(graph);
+ ConstantPropagation(graph);
// Unroll small loops, and eliminate expressions that are the same at every
// iteration.
CanonicalizeOps(g);
EliminateDeadCode(g);
}
-
} // namespace jit
} // namespace torch
// True if this value is statically known to be None; delegates to the
// analysis on the node that produces it.
bool Value::mustBeNone() const {
  return node_->mustBeNone();
}
+// True if this value can be statically proven never to be None: its type is
+// neither NoneType nor Optional, and it is not produced by prim::AutogradAdd
+// (presumably excluded because its output may be None-like — TODO confirm).
+bool Value::mustNotBeNone() const {
+ return node_->kind() != prim::AutogradAdd && type() != NoneType::get() &&
+ !type()->cast<OptionalType>();
+}
std::string Value::uniqueNameBase() const {
std::string name = uniqueName();
}
// True if this node's output is statically known to be None.
bool Node::mustBeNone() const {
- return kind_ == prim::Constant && !this->hasAttributes() &&
- (output()->type()->cast<OptionalType>() ||
- output()->type() == NoneType::get());
+ // prim::AutogradZero is treated as a guaranteed-None value; so is a
+ // payload-free prim::Constant whose output type is None (or Optional).
+ return kind_ == prim::AutogradZero ||
+ (kind_ == prim::Constant && !this->hasAttributes() &&
+ (output()->type()->cast<OptionalType>() ||
+ output()->type() == NoneType::get()));
}
void Node::dump() const {
return type()->kind() == TypeKind::CompleteTensorType;
}
TORCH_API bool mustBeNone() const;
+ TORCH_API bool mustNotBeNone() const;
size_t unique() const {
return unique_;
}
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
+#include <torch/csrc/jit/node_hashing.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
inlineIfBody(n->blocks().at(block_index));
}
+// Redirects all uses of the i-th output of if-node `n` to `replacement`,
+// then erases that output from the node and from both of its subblocks
+// (then-block at index 0, else-block at index 1) so all three stay in sync.
+void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
+ n->outputs().at(i)->replaceAllUsesWith(replacement);
+ n->eraseOutput(i);
+ n->blocks().at(0)->eraseOutput(i);
+ n->blocks().at(1)->eraseOutput(i);
+}
+
// remove extra outputs from the node
bool removeExtraIfOutputs(Node* n) {
AT_CHECK(n->kind() == prim::If, "Only supported for If nodes");
auto true_block = n->blocks()[0];
auto false_block = n->blocks()[1];
+ auto graph = n->owningGraph();
auto initial_outputs = true_block->outputs().size();
+ WithInsertPoint guard(n);
for (size_t i = 0; i < true_block->outputs().size();) {
+ auto t_out = true_block->outputs().at(i);
+ auto f_out = false_block->outputs().at(i);
+
// neither block changes the output value
if (true_block->outputs()[i] == false_block->outputs()[i]) {
- n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
- n->eraseOutput(i);
- true_block->eraseOutput(i);
- false_block->eraseOutput(i);
- } else {
- i++; // increment bc we didn't remove current index
+ replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
+ continue;
+ }
+
+ // true block output is constant and constant matches false block output
+ auto maybe_const = toIValue(t_out);
+ auto eq = EqualNode();
+ if (maybe_const && eq(t_out->node(), f_out->node())) {
+ auto new_const = graph->insertConstant(*maybe_const, t_out->type());
+ replaceAndRemoveIfOutput(n, i, new_const);
+ continue;
}
+
+ i++; // increment bc we didn't remove current index
}
// an output was removed
return initial_outputs != true_block->outputs().size();
ConstantPropagation(graph->block(), aliasDb);
EliminateDeadCode(graph);
}
-
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/passes/peephole.h>
-
+#include <torch/csrc/jit/ir_views.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
u.user->replaceInput(0, node->inputs().at(0));
}
}
+ } else if (node->kind() == prim::If) {
+ IfView n(node);
+ // this handles redundant short circuits like "x and True" or "x or False"
+ for (size_t i = 0; i < n.outputs().size(); ++i) {
+ if (n.outputs().at(i)->type() != BoolType::get()) {
+ continue;
+ }
+ bool true_val =
+ constant_as<bool>(n.thenOutputs().at(i)).value_or(false);
+ bool false_val =
+ constant_as<bool>(n.elseOutputs().at(i)).value_or(true);
+ // if an if node's output equals its condition replace output with
+ // condition
+ if (true_val && !false_val) {
+ n.outputs().at(i)->replaceAllUsesWith(n.cond());
+ }
+ }
+ } else if (
+ node->kind() == aten::__is__ || node->kind() == aten::__isnot__) {
+ // if we are comparing a None value with a value that can't be None
+ // replace the output with true if node is __isnot__ or false if node is
+ // __is__
+ AT_ASSERT(node->inputs().size() == 2);
+ for (size_t check_none_index : {0, 1}) {
+ bool input_must_be_none =
+ node->inputs().at(check_none_index)->mustBeNone();
+ bool other_must_not_be_none =
+ node->inputs().at(1 - check_none_index)->mustNotBeNone();
+ if (input_must_be_none && other_must_not_be_none) {
+ WithInsertPoint guard(node);
+ auto output = node->owningGraph()->insertConstant(
+ node->kind() == aten::__isnot__);
+ node->output()->replaceAllUsesWith(output);
+ }
+ }
+ } else if (
+ node->kind() == prim::unchecked_unwrap_optional ||
+ node->kind() == aten::_unwrap_optional) {
+ // we are unwrapping an input that can't be None, remove the unwrap
+ auto input = node->input();
+ if (input->mustNotBeNone()) {
+ node->output()->replaceAllUsesWith(node->input());
+ }
}
}
}
bool addmm_fusion_enabled) {
PeepholeOptimize(graph->block(), addmm_fusion_enabled);
}
-
} // namespace jit
} // namespace torch
const auto first_bool_info = findRefinements(first_expr);
Value* first_value = emitCond(Expr(first_expr));
+ // if the second expr in the short circuit is not evaluated,
+ // then the first expression is False if the short circuit
+ // is an `and` and True if the short circuit is an `or`.
+ // `False and expr` -> False, `True or expr` -> True
+ //
+ // inserting it as a constant makes optimization easier
+
+ Value* first_value_returned;
+
const Refinements* first_expr_refinements;
const Refinements* second_expr_refinements;
// if it's an OR the first expr is emitted in the true branch
// and the second expr in the false branch, if it's an AND the opposite
if (is_or) {
+ first_value_returned = graph->insertConstant(true, nullptr, loc);
first_expr_refinements = &first_bool_info.true_refinements_;
second_expr_refinements = &first_bool_info.false_refinements_;
} else {
+ first_value_returned = graph->insertConstant(false, nullptr, loc);
first_expr_refinements = &first_bool_info.false_refinements_;
second_expr_refinements = &first_bool_info.true_refinements_;
}
auto get_first_expr = [&] {
insertRefinements(*first_expr_refinements);
- return first_value;
+ return first_value_returned;
};
auto get_second_expr = [&] {
}
return classNew->createObject(
apply.range(), method, Var(apply.inputs()[0]).name().name());
- ;
} else {
auto inputs = getNamedValues(apply.inputs(), true);
auto attributes = emitAttributes(apply.attributes());
{"float", FloatType::get()},
{"int", IntType::get()},
{"bool", BoolType::get()},
+ {"None", NoneType::get()},
};
- auto tok = L.expect(TK_IDENT);
- auto text = tok.text();
+ auto tok = L.cur();
+ if (!L.nextIf(TK_NONE)) {
+ L.expect(TK_IDENT);
+ }
+ std::string text = tok.text();
+
auto it = type_map.find(text);
if (it == type_map.end()) {
if (text.size() > 0 && islower(text[0])) {