Hoisting common expressions out of If blocks (#59492)
author John Clow <jclow@fb.com>
Wed, 18 Aug 2021 23:28:02 +0000 (16:28 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 23:29:30 +0000 (16:29 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59492

Adds a pass that finds expressions common to the two subblocks of an if
node and hoists them before the if node. This also allows dead code
elimination to then remove some if nodes entirely.
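
For illustration, a minimal sketch of the new pass in action through the
Python binding this change adds (it mirrors test_if_hoist_basic below; the
names are illustrative):

    import torch

    def fn(x: bool, y: int):
        if x:
            z = y + 3  # identical expression on both branches,
        else:
            z = y + 3  # so it can be hoisted above the if
        return z

    scripted = torch.jit.script(fn)
    torch._C._jit_pass_common_expression_hoisting(scripted.graph)
    torch._C._jit_pass_dce(scripted.graph)
    print(scripted.graph)  # no prim::If remains; a single aten::add is left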

Also removes some dead code from the codebase (the unused
runDiffGraphPasses function in profiling_graph_executor_impl.cpp).

Test Plan:
python test_jit.py TestIfHoisting

Imported from OSS

Reviewed By: ngimel

Differential Revision: D29399533

fbshipit-source-id: 9336b9dc48c02c38862f98f98cd72fc1767a1802

12 files changed:
test/jit/test_if_hoisting.py [new file with mode: 0644]
test/quantization/jit/test_quantize_jit.py
test/test_jit.py
tools/build_variables.bzl
torch/_C/__init__.pyi.in
torch/csrc/jit/ir/node_hashing.cpp
torch/csrc/jit/passes/common_expression_hoisting.cpp [new file with mode: 0644]
torch/csrc/jit/passes/common_expression_hoisting.h [new file with mode: 0644]
torch/csrc/jit/passes/symbolic_shape_analysis.cpp
torch/csrc/jit/python/init.cpp
torch/csrc/jit/runtime/graph_executor.cpp
torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp

diff --git a/test/jit/test_if_hoisting.py b/test/jit/test_if_hoisting.py
new file mode 100644 (file)
index 0000000..c8fd4a4
--- /dev/null
@@ -0,0 +1,213 @@
+
+import torch
+from torch.testing import FileCheck
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
+
+
+class TestIfHoisting(JitTestCase):
+    def test_if_hoist_basic(self):
+        def fn(x: bool, y: int):
+            if x:
+                z = y + 3
+            else:
+                z = y + 3
+            return z
+
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+
+    def test_if_hoist_transposed_expr(self):
+        """
+        Make sure that we can properly eliminate
+        a common expression even when it is not
+        at the start of a block.
+        """
+        def fn(x: bool, y: int):
+            if x:
+                a = y + 3
+                b = y * 2
+            else:
+                b = y * 2
+                a = y + 3
+            return a, b
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+        self.assertEqual(fn(False, 5), fn_script(False, 5))
+
+    def test_if_hoist_swapped_expr(self):
+        """
+        Make sure that the if statement does not
+        get fully eliminated here.
+        """
+        def fn(x: bool, y: int):
+            if x:
+                a = y + 3
+                b = y * 2
+            else:
+                a = y * 2
+                b = y + 3
+            return a, b
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+        self.assertEqual(fn(False, 5), fn_script(False, 5))
+
+    def test_if_hoist_reused_var(self):
+        """
+        Make sure that cases where a Python variable is reused
+        are handled correctly.
+        """
+        def fn(x: bool, y: int):
+            b = 6
+            if x:
+                a = y + 3
+                a = y * 2
+            else:
+                a = y * 2
+                b = y + 3
+            return a, b
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::mul", 1, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1), fn_script(True, 1))
+        self.assertEqual(fn(False, 5), fn_script(False, 5))
+
+    def test_no_hoist(self):
+        """
+        Nothing should happen here; the expressions are different.
+        """
+        def fn(x: bool, y: int, z: int):
+            if x:
+                a = y + 3
+            else:
+                a = z + 3
+            return a
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+
+        self.assertEqual(fn(True, 1, 3), fn_script(True, 1, 3))
+        self.assertEqual(fn(False, 5, 10), fn_script(False, 5, 10))
+
+    def test_mutate_before(self):
+        """
+        Make sure that if there is a mutation before the common
+        op, the hoist doesn't happen
+        """
+        def fn(x: bool, y: torch.Tensor):
+            if x:
+                y.add_(8)
+                a = y + 3
+            else:
+                a = y + 3
+            return a
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add_", 1, exactly=True).run(op_graph)
+
+        t1 = torch.Tensor([1])
+        t2 = torch.Tensor([5, 6])
+        self.assertEqual(fn(True, t1), fn_script(True, t1))
+        self.assertEqual(fn(False, t2), fn_script(False, t2))
+
+    def test_mutate_after(self):
+        """
+        Check that the hoist happens even when a mutation
+        follows the common op, and that the output is correct.
+        """
+        def fn(x: bool, y: torch.Tensor):
+            if x:
+                b = 1
+                a = y + 3
+                y.add_(8)
+            else:
+                b = 2
+                a = y + 3
+            c = b + a
+            return a
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+
+        t1 = torch.Tensor([1])
+        t2 = torch.Tensor([5, 6])
+        self.assertEqual(fn(True, t1.clone()), fn_script(True, t1.clone()))
+        self.assertEqual(fn(False, t2.clone()), fn_script(False, t2.clone()))
+
+    def test_multiple_hoists(self):
+        """
+        Test that hoists that depend on other hoisted expressions are done correctly.
+        """
+        def fn(x: bool, y: torch.Tensor):
+            if x:
+                a = y + 3
+                b = a + y
+            else:
+                a = y + 3
+                b = a + y
+            c = b * 2
+            return c
+
+        fn_script = torch.jit.script(fn)
+        op_graph = fn_script.graph
+        self.run_pass("common_expression_hoisting", op_graph)
+        self.run_pass("dce", op_graph)
+
+        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
+        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
+
+        t1 = torch.Tensor([1])
+        t2 = torch.Tensor([5, 6])
+        self.assertEqual(fn(True, t1), fn_script(True, t1))
+        self.assertEqual(fn(False, t2), fn_script(False, t2))
index 14bb31c..5fde8e2 100644 (file)
@@ -1214,6 +1214,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
             def __init__(self):
                 super(Res, self).__init__()
                 self.conv = torch.nn.Conv2d(3, 3, 1).float()
+                self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
                 self.use_skip = True
 
             def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
@@ -1222,7 +1223,7 @@ class TestQuantizeJitPasses(QuantizationTestCase):
                 if self.use_skip:
                     return self.conv(x)
                 else:
-                    return self.conv(x)
+                    return self.conv2(x)
 
         class M(torch.nn.Module):
             def __init__(self):
index 99df960..6cf1d8e 100644 (file)
@@ -23,6 +23,7 @@ from jit.test_class_type import TestClassType  # noqa: F401
 from jit.test_builtins import TestBuiltins, TestTensorBuiltins  # noqa: F401
 from jit.test_ignore_context_manager import TestIgnoreContextManager  # noqa: F401
 from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis  # noqa: F401
+from jit.test_if_hoisting import TestIfHoisting  # noqa: F401
 from jit.test_unsupported_ops import TestUnsupportedOps  # noqa: F401
 from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing  # noqa: F401
 from jit.test_peephole import TestPeephole  # noqa: F401
index 89697b4..2e71bed 100644 (file)
@@ -191,6 +191,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/passes/clear_profiling.cpp",
     "torch/csrc/jit/passes/clear_undefinedness.cpp",
     "torch/csrc/jit/passes/common_subexpression_elimination.cpp",
+    "torch/csrc/jit/passes/common_expression_hoisting.cpp",
     "torch/csrc/jit/passes/concat_opt.cpp",
     "torch/csrc/jit/passes/constant_pooling.cpp",
     "torch/csrc/jit/passes/constant_propagation.cpp",
index b683a60..30885d3 100644 (file)
@@ -204,6 +204,7 @@ def _jit_pass_inline(Graph) -> None: ...
 def _jit_pass_constant_propagation(Graph) -> None: ...
 def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
 def _jit_erase_non_input_shape_information(Graph) -> None: ...
+def _jit_pass_common_expression_hoisting(Graph) -> None: ...
 def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
index 3fd4974..9a876d0 100644 (file)
@@ -204,6 +204,8 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
 
 } // anonymous namespace
 
+// Computes a hash over a node's input Values, output types,
+// and node attributes.
 size_t HashNode::operator()(const Node* k) const {
   AT_ASSERT(k != nullptr);
   size_t constant_hash = 0;
@@ -231,6 +233,8 @@ size_t HashNode::operator()(const Node* k) const {
       constant_hash);
 };
 
+// Checks that two nodes have the same inputs, output types
+// and node attributes.
 bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
   if (lhs == nullptr && rhs == nullptr)
     return true;
@@ -261,6 +265,16 @@ bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
   if (!attributesEqualCSE(lhs, rhs))
     return false;
 
+  // Check whether the blocks contained in an op are the same
+  if (lhs->blocks().size() != rhs->blocks().size()) {
+    return false;
+  }
+  for (size_t i = 0; i < lhs->blocks().size(); ++i) {
+    if (lhs->blocks()[i] != rhs->blocks()[i]) {
+      return false;
+    }
+  }
+
   return true;
 };
 
diff --git a/torch/csrc/jit/passes/common_expression_hoisting.cpp b/torch/csrc/jit/passes/common_expression_hoisting.cpp
new file mode 100644 (file)
index 0000000..ab2b9d4
--- /dev/null
@@ -0,0 +1,153 @@
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
+
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/node_hashing.h>
+#include <torch/csrc/jit/jit_log.h>
+
+#include <cstddef>
+#include <unordered_set>
+#include <vector>
+
+namespace torch {
+namespace jit {
+namespace {
+
+struct CommonExpressionHoister {
+  CommonExpressionHoister(std::shared_ptr<Graph> graph)
+      : graph_(std::move(graph)) {}
+
+  bool run() {
+    HoistCommonExpression(graph_->block());
+    return changed_;
+  }
+
+  void HoistFromIfNode(Node* if_node) {
+    Block* true_block = if_node->blocks()[0];
+    Block* false_block = if_node->blocks()[1];
+    // find common statements in the two subblocks
+
+    auto true_block_nodes = std::unordered_set<Node*, HashNode, EqualNode>(
+        true_block->nodes().begin(), true_block->nodes().end());
+    for (auto it = false_block->nodes().begin();
+         it != false_block->nodes().end();) {
+      Node* false_b_node = *it;
+      // node may be moved to a different block so advance iterator now
+      ++it;
+
+      auto matching_elem = true_block_nodes.find(false_b_node);
+      if (matching_elem == true_block_nodes.end()) {
+        continue;
+      }
+      Node* true_b_node = *matching_elem;
+
+      // Check whether a move to the front of each block is valid.
+      // If both moves are valid, then we know we can move the node
+      // out of the if blocks entirely.
+      AliasDb& aliasDb = getOrCreateAliasDb();
+      bool true_moveable = aliasDb.couldMoveAfterTopologically(
+          true_b_node, true_block->nodes().front());
+      bool false_moveable = aliasDb.couldMoveAfterTopologically(
+          false_b_node, false_block->nodes().front());
+
+      if (!true_moveable || !false_moveable) {
+        continue;
+      }
+
+      // Collect all uses of the node's outputs so they can be removed from
+      // and reinserted into the hash set: replacing an input changes the hash.
+      std::unordered_set<Node*> true_b_uses;
+      for (Value* true_out : true_b_node->outputs()) {
+        for (Use true_use : true_out->uses()) {
+          if (true_use.user->owningBlock() == true_block) {
+            // Make sure we are not accidentally adding uses from subblocks
+            true_b_uses.insert(true_use.user);
+          }
+        }
+      }
+      for (Node* uses_node : true_b_uses) {
+        true_block_nodes.erase(uses_node);
+      }
+
+      // Now hoist the statement out of the block
+      changed_ = true;
+      false_b_node->moveBefore(if_node);
+
+      true_b_node->replaceAllUsesWith(false_b_node);
+
+      true_block_nodes.erase(true_b_node);
+      true_block_nodes.insert(true_b_uses.cbegin(), true_b_uses.cend());
+      true_b_node->destroy();
+    }
+  }
+
+  void EliminateUnnecessaryIfOutputs(Node* if_node) {
+    Block* true_block = if_node->blocks()[0];
+    Block* false_block = if_node->blocks()[1];
+
+    // fix up the if block outputs
+    for (size_t i = 0; i < true_block->outputs().size();) {
+      // Need to check both sides match to eliminate common if block outputs
+      Value* true_block_output = true_block->outputs().at(i);
+      Value* false_block_output = false_block->outputs().at(i);
+      if (true_block_output != false_block_output) {
+        i++;
+        continue;
+      }
+
+      // We have a matching output, and can remove it from the block itself
+      if_node->outputs().at(i)->replaceAllUsesWith(true_block_output);
+      if_node->eraseOutput(i);
+      true_block->eraseOutput(i);
+      false_block->eraseOutput(i);
+      changed_ = true;
+    }
+
+    // No need to test here if the IF block should be eliminated.
+    // The DCE pass will determine that for us.
+  }
+
+  void HoistCommonExpression(Block* block) {
+    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+      Node* node = *it;
+      ++it;
+
+      for (auto sub_block : node->blocks()) {
+        HoistCommonExpression(sub_block);
+      }
+
+      if (node->kind() == prim::If) {
+        HoistFromIfNode(node);
+        EliminateUnnecessaryIfOutputs(node);
+      }
+    }
+  }
+
+  AliasDb& getOrCreateAliasDb() {
+    if (!alias_db_) {
+      alias_db_ = std::make_unique<AliasDb>(graph_);
+    }
+
+    return *alias_db_;
+  }
+
+ private:
+  std::unique_ptr<AliasDb> alias_db_;
+  std::shared_ptr<Graph> graph_;
+  bool changed_ = false;
+};
+} // anonymous namespace
+bool HoistCommonExpression(const std::shared_ptr<Graph>& graph) {
+  // This moves common subexpressions from the two sides of an
+  // if block out of the if block.
+
+  GRAPH_DUMP("Before CEH", graph);
+  CommonExpressionHoister ceh(graph);
+  bool changed = ceh.run();
+  if (changed) {
+    GRAPH_DUMP("After CEH Changes", graph);
+  }
+  return changed;
+}
+} // namespace jit
+} // namespace torch
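
A companion sketch of the partial case, mirroring test_if_hoist_swapped_expr
above: both common subexpressions are hoisted, but the prim::If survives
because the branches route the hoisted values to different outputs.

    import torch

    def fn(x: bool, y: int):
        if x:
            a = y + 3
            b = y * 2
        else:
            a = y * 2  # same expressions, swapped assignment targets
            b = y + 3
        return a, b

    s = torch.jit.script(fn)
    torch._C._jit_pass_common_expression_hoisting(s.graph)
    torch._C._jit_pass_dce(s.graph)
    print(s.graph)  # one aten::add and one aten::mul remain, above the prim::If
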
diff --git a/torch/csrc/jit/passes/common_expression_hoisting.h b/torch/csrc/jit/passes/common_expression_hoisting.h
new file mode 100644 (file)
index 0000000..2aad158
--- /dev/null
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+TORCH_API bool HoistCommonExpression(const std::shared_ptr<Graph>& graph);
+} // namespace jit
+} // namespace torch
index f74a911..10edfb4 100644 (file)
@@ -6,6 +6,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/ir_views.h>
 #include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
index 5fca575..d582035 100644 (file)
@@ -12,6 +12,7 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/canonicalize.h>
 #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
@@ -283,6 +284,11 @@ void initJITBindings(PyObject* module) {
             return EliminateCommonSubexpression(g); // overload resolution
           })
       .def(
+          "_jit_pass_common_expression_hoisting",
+          [](std::shared_ptr<Graph>& g) {
+            return HoistCommonExpression(g); // overload resolution
+          })
+      .def(
           "_jit_pass_fuse_quantized_add_relu",
           [](std::shared_ptr<Graph>& g) {
             return FuseQuantizedAddRelu(g); // overload resolution
index 4768826..bb5f272 100644 (file)
@@ -9,6 +9,7 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/batch_mm.h>
 #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
@@ -918,7 +919,7 @@ void runOptimization(
       "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);
   EliminateCommonSubexpression(graph);
   GRAPH_DEBUG(
-      "After EliminateCommonSubexpression, before PeepholeOptimize\n", *graph);
+      "After EliminateCommonSubexpression , before PeepholeOptimize\n", *graph);
 
   PeepholeOptimize(graph);
   GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
@@ -949,8 +950,10 @@ void runOptimization(
 
   EliminateCommonSubexpression(graph);
   GRAPH_DEBUG(
-      "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
-
+      "After EliminateCommonSubexpression, before HoistCommonExpression\n",
+      *graph);
+  HoistCommonExpression(graph);
+  GRAPH_DEBUG("After HoistCommonExpression, before CheckInplace\n", *graph);
   CheckInplace(graph);
   GRAPH_DEBUG("After CheckInplace (end of runOptimization)", *graph);
 }
index b099db1..40d94a4 100644 (file)
@@ -7,6 +7,7 @@
 #include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
 #include <torch/csrc/jit/passes/clear_profiling.h>
 #include <torch/csrc/jit/passes/clear_undefinedness.h>
+#include <torch/csrc/jit/passes/common_expression_hoisting.h>
 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
@@ -332,112 +333,16 @@ void runPreAutodiffPassPipeline(std::shared_ptr<Graph>& graph) {
 
     EliminateCommonSubexpression(graph);
     GRAPH_DEBUG(
-        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
-
+        "After EliminateCommonSubexpression, before HoistCommonExpression\n",
+        *graph);
+    HoistCommonExpression(graph);
+    GRAPH_DEBUG("After HoistCommonExpression, before CheckInplace\n", *graph);
     CheckInplace(graph);
   }
   GRAPH_DEBUG(
       "After CheckInplace (end of runPreAutodiffPassPipeline)\n", *graph);
 }
 
-void runDiffGraphPasses(std::shared_ptr<Graph>& graph) {
-  GRAPH_DEBUG(
-      "Before EliminateDeadCode (beginning of runDiffGraphPasses)\n", *graph);
-  // runOptimization:
-  {
-    // Basic graph preprocessing to eliminate noise.
-    EliminateDeadCode(graph);
-    GRAPH_DEBUG(
-        "After EliminateDeadCode, before EliminateCommonSubexpression\n",
-        *graph);
-    EliminateCommonSubexpression(graph);
-    GRAPH_DEBUG(
-        "After EliminateCommonSubexpression, before PeepholeOptimize\n",
-        *graph);
-
-    PeepholeOptimize(graph);
-    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
-    ConstantPropagation(graph);
-    GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
-    ConstantPooling(graph);
-    GRAPH_DEBUG("After ConstantPooling, before UnrollLoops\n", *graph);
-
-    UnrollLoops(graph);
-    GRAPH_DEBUG("After UnrollLoops, before RemoveListMutation\n", *graph);
-    // run again with unrolled loops
-    RemoveListMutation(graph);
-    GRAPH_DEBUG("After RemoveListMutation, before PeepholeOptimize\n", *graph);
-    PeepholeOptimize(graph);
-    GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
-    ConstantPropagation(graph);
-    GRAPH_DEBUG(
-        "After ConstantPropagation, before EliminateCommonSubexpression\n",
-        *graph);
-
-    EliminateCommonSubexpression(graph);
-    GRAPH_DEBUG(
-        "After EliminateCommonSubexpression, before CheckInplace\n", *graph);
-
-    CheckInplace(graph);
-  }
-  GRAPH_DEBUG("After CheckInplace, before customPrePasses\n", *graph);
-
-  // runNondiffOptimization
-  {
-    // Run custom passes that different backends can register.
-    for (const auto& passPair : getCustomPrePasses()) {
-      passPair.first(graph);
-    }
-    GRAPH_DEBUG("After customPrePasses, before LowerSimpleTuples\n", *graph);
-
-    // TupleConstruct / TupleUnpack pairs can still be present at this point
-    // and must be removed for fusion.
-    LowerSimpleTuples(graph);
-    GRAPH_DEBUG("After LowerSimpleTuples\n", *graph);
-
-    if (tensorExprFuserEnabled()) {
-      // Remove prim::profile nodes and embed the profile info directly in the
-      // IR in value types. We're doing such transformation as optimizations
-      // that try to merge/fuse nodes in the graph (e.g. BatchMM and GraphFuser)
-      // work worse in the presence of intermittent prim::profile nodes.
-      // Optimizations relying on the type info are also responsible for
-      // inserting proper type checks. Once we're done with these optimizations
-      // we will wipe the tensor type information from the IR, so that it's not
-      // accidentally used by any other pass.
-      RemoveProfileNodesAndSpecializeTypes(graph);
-      GRAPH_DEBUG(
-          "After RemoveProfileNodesAndSpecializeTypes, before BatchMM\n",
-          *graph);
-      // Rewrite subgraphs with many MMs into expressions that batch them.
-      BatchMM(graph);
-      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
-
-      FuseTensorExprs(graph, getFusionGroupInlining() ? 2 : 1);
-      GRAPH_DEBUG(
-          "After Fusion, before RemoveTensorTypeSpecializations\n", *graph);
-
-      // Wipe tensor type info from the IR
-      RemoveTensorTypeSpecializations(graph);
-      GRAPH_DEBUG(
-          "After RemoveTensorTypeSpecializations, before customPostPasses\n",
-          *graph);
-    } else {
-      // Rewrite subgraphs with many MMs into expressions that batch them.
-      BatchMM(graph);
-      GRAPH_DEBUG("After BatchMM, before Fusion\n", *graph);
-
-      FuseGraph(graph, true);
-      GRAPH_DEBUG("After Fusion, before customPostPasses\n", *graph);
-    }
-
-    // Run custom post-fusion passes
-    for (const auto& passPair : getCustomPostPasses()) {
-      passPair.first(graph);
-    }
-  }
-  GRAPH_DEBUG("After customPostPasses (end of runDiffGraphPasses)\n", *graph);
-}
-
 void runNoGradOptimizations(std::shared_ptr<Graph>& graph) {
   GRAPH_DEBUG(
       "After customPostPasses (beginning of runNoGradOptimizations)\n", *graph);
@@ -593,7 +498,11 @@ void ProfilingGraphExecutorImpl::runProfilingInsensitiveOptimizations(
   DecomposeOps(graph);
   GRAPH_DEBUG("After DecomposeOps, before ConstantPropagation\n", *graph);
   ConstantPropagation(graph);
-  GRAPH_DEBUG("After ConstantPropagation, before EliminateDeadCode\n", *graph);
+  GRAPH_DEBUG(
+      "After ConstantPropagation, before HoistCommonExpression\n", *graph);
+  HoistCommonExpression(graph);
+  GRAPH_DEBUG(
+      "After EliminateCommonSubexpression, before ElimiateDeadCode\n", *graph);
   EliminateDeadCode(graph);
   GRAPH_DEBUG(
       "After EliminateDeadCode, before EliminateCommonSubexpression\n", *graph);