--- /dev/null
+
+import torch
+from torch.testing import FileCheck
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
+
+
+class TestIfHoisting(JitTestCase):
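+    """
+    Tests for the common_expression_hoisting pass, which hoists
+    expressions that are identical in both branches of an if
+    statement out of the if node.
+    """
+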
+    def test_if_hoist_basic(self):
+        def fn(x: bool, y: int):
+            if x:
+                z = y + 3
+            else:
+                z = y + 3
+            return z
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+
+    def test_if_hoist_transposed_expr(self):
+        """
+        Making sure that we can properly eliminate
+        an expression even if it is not at the start
+        of a block.
+        """
+        def fn(x: bool, y: int):
+            if x:
+                a = y + 3
+                b = y * 2
+            else:
+                b = y * 2
+                a = y + 3
+            return a, b
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+        self.assertEqual(fn(False, 5), fn_script(False, 5))
+
+    def test_if_hoist_swapped_expr(self):
+        """
+        Making sure that the if statement
+        doesn't get fully eliminated here.
+        """
+        def fn(x: bool, y: int):
+            if x:
+                a = y + 3
+                b = y * 2
+            else:
+                a = y * 2
+                b = y + 3
+            return a, b
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+        self.assertEqual(fn(False, 5), fn_script(False, 5))
+
+    def test_if_hoist_reused_var(self):
+        """
+        Making sure that cases where a Python variable is reused
+        are handled correctly.
+        """
+        def fn(x: bool, y: int):
+            b = 6
+            if x:
+                a = y + 3
+                a = y * 2
+            else:
+                a = y * 2
+                b = y + 3
+            return a, b
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::mul", 1, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+        self.assertEqual(fn(False, 5), fn_script(False, 5))
+
+    def test_no_hoist(self):
+        """
+        Nothing should happen here; the expressions are different.
+        """
+        def fn(x: bool, y: int, z: int):
+            if x:
+                a = y + 3
+            else:
+                a = z + 3
+            return a
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1, 3), fn_script(True, 1, 3))
+        self.assertEqual(fn(False, 5, 10), fn_script(False, 5, 10))
+
+    def test_mutate_before(self):
+        """
+        Make sure that if there is a mutation before the common
+        op, the hoist doesn't happen.
+        """
+        def fn(x: bool, y: torch.Tensor):
+            if x:
+                y.add_(8)
+                a = y + 3
+            else:
+                a = y + 3
+            return a
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add_", 1, exactly=True).run(op_graph)
+
+        t1 = torch.Tensor([1])
+        t2 = torch.Tensor([5, 6])
+        # Clone the inputs: fn mutates y in place, so the eager and scripted
+        # calls must each start from a fresh tensor for the results to match.
+        self.assertEqual(fn(True, t1.clone()), fn_script(True, t1.clone()))
+        self.assertEqual(fn(False, t2.clone()), fn_script(False, t2.clone()))
+
+    def test_mutate_after(self):
+        """
+        Check that the hoist can happen properly, and
+        that the output is still correct.
+        """
+        def fn(x: bool, y: torch.Tensor):
+            if x:
+                b = 1
+                a = y + 3
+                y.add_(8)
+            else:
+                b = 2
+                a = y + 3
+            c = b + a
+            return a
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+
+        t1 = torch.Tensor([1])
+        t2 = torch.Tensor([5, 6])
+        self.assertEqual(fn(True, t1.clone()), fn_script(True, t1.clone()))
+        self.assertEqual(fn(False, t2.clone()), fn_script(False, t2.clone()))
+
+    def test_multiple_hoists(self):
+        """
+        Test that hoists that depend on other hoists are done correctly.
+        """
+        def fn(x: bool, y: torch.Tensor):
+            if x:
+                a = y + 3
+                b = a + y
+            else:
+                a = y + 3
+                b = a + y
+            c = b * 2
+            return c
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+
+        t1 = torch.Tensor([1])
+        t2 = torch.Tensor([5, 6])
+        self.assertEqual(fn(True, t1), fn_script(True, t1))
+        self.assertEqual(fn(False, t2), fn_script(False, t2))
    def __init__(self):
        super(Res, self).__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1).float()
+        self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
        self.use_skip = True
    def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
        if self.use_skip:
            return self.conv(x)
        else:
-            return self.conv(x)
+            return self.conv2(x)
class M(torch.nn.Module):
    def __init__(self):
from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401
from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis # noqa: F401
+from jit.test_if_hoisting import TestIfHoisting # noqa: F401
from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
from jit.test_peephole import TestPeephole # noqa: F401
"torch/csrc/jit/passes/clear_profiling.cpp",
"torch/csrc/jit/passes/clear_undefinedness.cpp",
"torch/csrc/jit/passes/common_subexpression_elimination.cpp",
+ "torch/csrc/jit/passes/common_expression_hoisting.cpp",
"torch/csrc/jit/passes/concat_opt.cpp",
"torch/csrc/jit/passes/constant_pooling.cpp",
"torch/csrc/jit/passes/constant_propagation.cpp",
def _jit_pass_constant_propagation(Graph) -> None: ...
def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
def _jit_erase_non_input_shape_information(Graph) -> None: ...
+def _jit_pass_common_expression_hoisting(Graph) -> None: ...
def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
def _jit_can_fuse_on_cpu() -> _bool: ...
} // anonymous namespace
+// Makes a hash that combines a node's input values, its output types,
+// and its node attributes.
size_t HashNode::operator()(const Node* k) const {
  AT_ASSERT(k != nullptr);
  size_t constant_hash = 0;
      constant_hash);
};
+// Checks that two nodes have the same inputs, output types,
+// and node attributes.
bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
  if (lhs == nullptr && rhs == nullptr)
    return true;
  if (!attributesEqualCSE(lhs, rhs))
    return false;
+  // Check that any blocks contained in the two ops are the same. Blocks
+  // are compared by identity, so nodes containing subblocks only compare
+  // equal when they share the same block objects.
+  if (lhs->blocks().size() != rhs->blocks().size()) {
+    return false;
+  }
+  for (size_t i = 0; i < lhs->blocks().size(); ++i) {
+    if (lhs->blocks()[i] != rhs->blocks()[i]) {
+      return false;
+    }
+  }
+
  return true;
};
--- /dev/null
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
+
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/node_hashing.h>
+#include <torch/csrc/jit/jit_log.h>
+
+#include <cstddef>
+#include <unordered_set>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace {
+
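+// Hoists nodes that appear (by hash and structural equality) in both the
+// true and false blocks of a prim::If node out in front of the node,
+// whenever alias analysis shows the move is safe.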
+struct CommonExpressionHoister {
+  CommonExpressionHoister(std::shared_ptr<Graph> graph)
+      : graph_(std::move(graph)) {}
+
+  bool run() {
+    HoistCommonExpression(graph_->block());
+    return changed_;
+  }
+
+  void HoistFromIfNode(Node* if_node) {
+    Block* true_block = if_node->blocks()[0];
+    Block* false_block = if_node->blocks()[1];
+
+    // Find statements common to the two subblocks.
+    auto true_block_nodes = std::unordered_set<Node*, HashNode, EqualNode>(
+        true_block->nodes().begin(), true_block->nodes().end());
+    for (auto it = false_block->nodes().begin();
+         it != false_block->nodes().end();) {
+      Node* false_b_node = *it;
+      // The node may be moved to a different block, so advance the
+      // iterator now.
+      ++it;
+
+      auto matching_elem = true_block_nodes.find(false_b_node);
+      if (matching_elem == true_block_nodes.end()) {
+        continue;
+      }
+      Node* true_b_node = *matching_elem;
+
+      // Check if a move to the front of the block is valid.
+      // If both of the moves are valid, then we know we can move the item
+      // out of the if blocks entirely.
+      AliasDb& aliasDb = getOrCreateAliasDb();
+      bool true_moveable = aliasDb.couldMoveAfterTopologically(
+          true_b_node, true_block->nodes().front());
+      bool false_moveable = aliasDb.couldMoveAfterTopologically(
+          false_b_node, false_block->nodes().front());
+
+      if (!true_moveable || !false_moveable) {
+        continue;
+      }
+
+      // Remove all users of this node's outputs from the set and reinsert
+      // them afterwards: once their input changes, their HashNode value
+      // changes as well.
+      std::unordered_set<Node*> true_b_uses;
+      for (Value* true_out : true_b_node->outputs()) {
+        for (Use true_use : true_out->uses()) {
+          if (true_use.user->owningBlock() == true_block) {
+            // Make sure we are not accidentally adding nodes from subblocks.
+            true_b_uses.insert(true_use.user);
+          }
+        }
+      }
+      for (Node* uses_node : true_b_uses) {
+        true_block_nodes.erase(uses_node);
+      }
+
+      // Now hoist the statement out of the block.
+      changed_ = true;
+      false_b_node->moveBefore(if_node);
+
+      true_b_node->replaceAllUsesWith(false_b_node);
+
+      true_block_nodes.erase(true_b_node);
+      true_block_nodes.insert(true_b_uses.cbegin(), true_b_uses.cend());
+      true_b_node->destroy();
+    }
+  }
+
+  void EliminateUnnecessaryIfOutputs(Node* if_node) {
+    Block* true_block = if_node->blocks()[0];
+    Block* false_block = if_node->blocks()[1];
+
+    // Fix up the if node outputs.
+    for (size_t i = 0; i < true_block->outputs().size();) {
+      // Both blocks must return the same value to eliminate an if output.
+      Value* true_block_output = true_block->outputs().at(i);
+      Value* false_block_output = false_block->outputs().at(i);
+      if (true_block_output != false_block_output) {
+        i++;
+        continue;
+      }
+
+      // We have a matching output and can route users of the if output
+      // directly to the underlying value. Erasing shifts the remaining
+      // outputs down, so i is only incremented when nothing is erased.
+      if_node->outputs().at(i)->replaceAllUsesWith(true_block_output);
+      if_node->eraseOutput(i);
+      true_block->eraseOutput(i);
+      false_block->eraseOutput(i);
+      changed_ = true;
+    }
+
+    // No need to test here whether the if node itself can be eliminated;
+    // the DCE pass will determine that for us.
+  }
+
+  void HoistCommonExpression(Block* block) {
+    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+      Node* node = *it;
+      ++it;
+
+      for (auto sub_block : node->blocks()) {
+        HoistCommonExpression(sub_block);
+      }
+
+      if (node->kind() == prim::If) {
+        HoistFromIfNode(node);
+        EliminateUnnecessaryIfOutputs(node);
+      }
+    }
+  }
+
+  AliasDb& getOrCreateAliasDb() {
+    if (!alias_db_) {
+      alias_db_ = std::make_unique<AliasDb>(graph_);
+    }
+    return *alias_db_;
+  }
+
+ private:
+  std::unique_ptr<AliasDb> alias_db_;
+  std::shared_ptr<Graph> graph_;
+  bool changed_ = false;
+};
+} // anonymous namespace
+bool HoistCommonExpression(const std::shared_ptr<Graph>& graph) {
+  // This moves common subexpressions from the two sides of an
+  // if block out of the if block.
+
+  GRAPH_DUMP("Before CEH", graph);
+  CommonExpressionHoister ceh(graph);
+  bool changed = ceh.run();
+  if (changed) {
+    GRAPH_DUMP("After CEH Changes", graph);
+  }
+  return changed;
+}
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
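+// Hoists expressions that appear in both branches of an if node out of
+// the node and removes if outputs that are identical on both sides.
+// Returns true if the graph was modified.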
+TORCH_API bool HoistCommonExpression(const std::shared_ptr<Graph>& graph);
+} // namespace jit
+} // namespace torch
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
            return EliminateCommonSubexpression(g); // overload resolution
          })
      .def(
+          "_jit_pass_common_expression_hoisting",
+          [](std::shared_ptr<Graph>& g) {
+            return HoistCommonExpression(g); // overload resolution
+          })
+      .def(
           "_jit_pass_fuse_quantized_add_relu",
           [](std::shared_ptr<Graph>& g) {
             return FuseQuantizedAddRelu(g); // overload resolution
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
"After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);
EliminateCommonSubexpression(graph);
GRAPH_DEBUG(
- "After EliminateCommonSubexpression, before PeepholeOptimize\n", *graph);
+ "After EliminateCommonSubexpression , before PeepholeOptimize\n", *graph);
PeepholeOptimize(graph);
GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
EliminateCommonSubexpression(graph);
GRAPH_DEBUG(
- "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
-
+ "After EliminateCommonSubexpression, before HoistCommonExpression\n",
+ *graph);
+ HoistCommonExpression(graph);
+ GRAPH_DEBUG("After HoistCommonExpression, before CheckInplace\n", *graph);
CheckInplace(graph);
GRAPH_DEBUG("After CheckInplace (end of runOptimization)", *graph);
}
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
    EliminateCommonSubexpression(graph);
    GRAPH_DEBUG(
-        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
-
+        "After EliminateCommonSubexpression, before HoistCommonExpression\n",
+        *graph);
+    HoistCommonExpression(graph);
+    GRAPH_DEBUG("After HoistCommonExpression, before CheckInplace\n", *graph);
    CheckInplace(graph);
  }
  GRAPH_DEBUG(
      "After CheckInplace (end of runPreAutodiffPassPipeline)\n", *graph);
}
-void runDiffGraphPasses(std::shared_ptr<Graph>& graph) {
-  GRAPH_DEBUG(
-      "Before EliminateDeadCode (beginning of runDiffGraphPasses)\n", *graph);
-  // runOptimization:
-  {
-    // Basic graph preprocessing to eliminate noise.
-    EliminateDeadCode(graph);
-    GRAPH_DEBUG(
-        "After EliminateDeadCode, before EliminateCommonSubexpression\n",
-        *graph);
-    EliminateCommonSubexpression(graph);
-    GRAPH_DEBUG(
-        "After EliminateCommonSubexpression, before PeepholeOptimize\n",
-        *graph);
-
-    PeepholeOptimize(graph);
-    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
-    ConstantPropagation(graph);
-    GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
-    ConstantPooling(graph);
-    GRAPH_DEBUG("After ConstantPooling, before UnrollLoops\n", *graph);
-
-    UnrollLoops(graph);
-    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
-    // run again with unrolled loops
-    RemoveListMutation(graph);
-    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
-    PeepholeOptimize(graph);
-    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
-    ConstantPropagation(graph);
-    GRAPH_DEBUG(
-        "After ConstantPropagation, before EliminateCommonSubexpression\n",
-        *graph);
-
-    EliminateCommonSubexpression(graph);
-    GRAPH_DEBUG(
-        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
-
-    CheckInplace(graph);
-  }
-  GRAPH_DEBUG("After CheckInplace, before customPrePasses\n", *graph);
-
-  // runNondiffOptimization
-  {
-    // Run custom passes that different backends can register.
-    for (const auto& passPair : getCustomPrePasses()) {
-      passPair.first(graph);
-    }
-    GRAPH_DEBUG("After customPrePasses, before LowerSimpleTuples\n", *graph);
-
-    // TupleConstruct / TupleUnpack pairs can still be present at this point
-    // and must be removed for fusion.
-    LowerSimpleTuples(graph);
-    GRAPH_DEBUG("After LowerSimpleTuples\n", *graph);
-
-    if (tensorExprFuserEnabled()) {
-      // Remove prim::profile nodes and embed the profile info directly in the
-      // IR in value types. We're doing such transformation as optimizations
-      // that try to merge/fuse nodes in the graph (e.g. BatchMM and GraphFuser)
-      // work worse in the presence of intermittent prim::profile nodes.
-      // Optimizations relying on the type info are also responsible for
-      // inserting proper type checks. Once we're done with these optimizations
-      // we will wipe the tensor type information from the IR, so that it's not
-      // accidentally used by any other pass.
-      RemoveProfileNodesAndSpecializeTypes(graph);
-      GRAPH_DEBUG(
-          "After RemoveProfileNodesAndSpecializeTypes, before BatchMM\n",
-          *graph);
-      // Rewrite subgraphs with many MMs into expressions that batch them.
-      BatchMM(graph);
-      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
-
-      FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1);
-      GRAPH_DEBUG(
-          "After Fusion, before RemoveTensorTypeSpecializations\n", *graph);
-
-      // Wipe tensor type info from the IR
-      RemoveTensorTypeSpecializations(graph);
-      GRAPH_DEBUG(
-          "After RemoveTensorTypeSpecializations, before customPostPasses\n",
-          *graph);
-    } else {
-      // Rewrite subgraphs with many MMs into expressions that batch them.
-      BatchMM(graph);
-      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
-
-      FuseGraph(graph, true);
-      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
-    }
-
-    // Run custom post-fusion passes
-    for (const auto& passPair : getCustomPostPasses()) {
-      passPair.first(graph);
-    }
-  }
-  GRAPH_DEBUG("After customPostPasses (end of runDiffGraphPasses)\n", *graph);
-}
-
void runNoGradOptimizations(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "After customPostPasses (beginning of runNoGradOptimizations)\n", *graph);
  DecomposeOps(graph);
  GRAPH_DEBUG("After DecomposeOps, before ConstantPropagation\n", *graph);
  ConstantPropagation(graph);
-  GRAPH_DEBUG("After ConstantPropagation, before EliminateDeadCode\n", *graph);
+  GRAPH_DEBUG(
+      "After ConstantPropagation, before HoistCommonExpression\n", *graph);
+  HoistCommonExpression(graph);
+  GRAPH_DEBUG(
+      "After HoistCommonExpression, before EliminateDeadCode\n", *graph);
  EliminateDeadCode(graph);
  GRAPH_DEBUG(
      "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);