[mlir][CSE] Remove duplicated operations with MemRead side-effect
authorValentin Clement <clementval@gmail.com>
Thu, 7 Apr 2022 08:06:50 +0000 (10:06 +0200)
committerValentin Clement <clementval@gmail.com>
Thu, 7 Apr 2022 08:08:55 +0000 (10:08 +0200)
This patch enhances the CSE pass to deal with simple cases of duplicated
operations with MemoryEffects.

It allows the CSE pass to remove safely duplicate operations with the
MemoryEffects::Read that have no other side-effecting operations in
between. Other MemoryEffects::Read operation are allowed.

The use case is pretty simple so far so we can build on top of it to add
more features.

This patch is also meant to avoid a dedicated CSE pass in FIR and was
brought together afetr discussion on https://reviews.llvm.org/D112711.
It does not currently cover the full range of use cases described in
https://reviews.llvm.org/D112711 but the idea is to gradually enhance
the MLIR CSE pass to handle common use cases that can be used by
other dialects.

This patch takes advantage of the new CSE capabilities in Fir.

Reviewed By: mehdi_amini, rriddle, schweitz

Differential Revision: https://reviews.llvm.org/D122801

flang/include/flang/Optimizer/Dialect/FIROps.td
flang/test/Fir/cse.fir [new file with mode: 0644]
mlir/lib/Transforms/CSE.cpp
mlir/test/Examples/Toy/Ch5/affine-lowering.mlir
mlir/test/Examples/Toy/Ch6/affine-lowering.mlir
mlir/test/Examples/Toy/Ch7/affine-lowering.mlir
mlir/test/Transforms/cse.mlir
mlir/test/lib/Dialect/Test/TestOps.td

index 262d953..6eb0fdf 100644 (file)
@@ -253,7 +253,7 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
   let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
 }
 
-def fir_LoadOp : fir_OneResultOp<"load"> {
+def fir_LoadOp : fir_OneResultOp<"load", [MemoryEffects<[MemRead]>]> {
   let summary = "load a value from a memory reference";
   let description = [{
     Load a value from a memory reference into an ssa-value (virtual register).
@@ -320,7 +320,7 @@ def fir_CharConvertOp : fir_Op<"char_convert", []> {
   let hasVerifier = 1;
 }
 
-def fir_StoreOp : fir_Op<"store", []> {
+def fir_StoreOp : fir_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let summary = "store an SSA-value to a memory location";
 
   let description = [{
diff --git a/flang/test/Fir/cse.fir b/flang/test/Fir/cse.fir
new file mode 100644 (file)
index 0000000..148b689
--- /dev/null
@@ -0,0 +1,57 @@
+// RUN: fir-opt --cse -split-input-file %s | FileCheck %s
+
+// Check that the redundant fir.load is removed.
+func @fun(%arg0: !fir.ref<i64>) -> i64 {
+    %0 = fir.load %arg0 : !fir.ref<i64>
+    %1 = fir.load %arg0 : !fir.ref<i64>
+    %2 = arith.addi %0, %1 : i64
+    return %2 : i64
+}
+
+// CHECK-LABEL: func @fun
+// CHECK-NEXT:    %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref<i64>
+// CHECK-NEXT:    %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
+
+// -----
+
+// CHECK-LABEL: func @fun(
+// CHECK-SAME:            %[[A:.*]]: !fir.ref<i64>
+func @fun(%a : !fir.ref<i64>) -> i64 {
+  // CHECK: %[[LOAD:.*]] = fir.load %[[A]] : !fir.ref<i64>
+  %1 = fir.load %a : !fir.ref<i64>
+  %2 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
+  %3 = arith.addi %1, %2 : i64
+  %4 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %5 = arith.addi %3, %4 : i64
+  %6 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %7 = arith.addi %5, %6 : i64
+  %8 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %9 = arith.addi %7, %8 : i64
+  %10 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %11 = arith.addi %10, %9 : i64
+  %12 = fir.load %a : !fir.ref<i64>
+  // CHECK-NEXT: %{{.*}} = arith.addi
+  %13 = arith.addi %11, %12 : i64
+  // CHECK-NEXT: return %{{.*}} : i64
+  return %13 : i64
+}
+
+// -----
+
+func @fun(%a : !fir.ref<i64>) -> i64 {
+  cf.br ^bb1
+^bb1:
+  %1 = fir.load %a : !fir.ref<i64>
+  %2 = fir.load %a : !fir.ref<i64>
+  %3 = arith.addi %1, %2 : i64
+  cf.br ^bb2
+^bb2:
+  %4 = fir.load %a : !fir.ref<i64>
+  %5 = arith.subi %4, %4 : i64
+  return %5 : i64
+}
index 0570c91..080e393 100644 (file)
@@ -60,6 +60,14 @@ struct CSE : public CSEBase<CSE> {
   using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
                                             SimpleOperationInfo, AllocatorTy>;
 
+  /// Cache holding MemoryEffects information between two operations. The first
+  /// operation is stored has the key. The second operation is stored inside a
+  /// pair in the value. The pair also hold the MemoryEffects between those
+  /// two operations. If the MemoryEffects is nullptr then we assume there is
+  /// no operation with MemoryEffects::Write between the two operations.
+  using MemEffectsCache =
+      DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
+
   /// Represents a single entry in the depth first traversal of a CFG.
   struct CFGStackNode {
     CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
@@ -85,12 +93,94 @@ struct CSE : public CSEBase<CSE> {
   void runOnOperation() override;
 
 private:
+  void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                            Operation *existing, bool hasSSADominance);
+
+  /// Check if there is side-effecting operations other than the given effect
+  /// between the two operations.
+  bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
+
   /// Operations marked as dead and to be erased.
   std::vector<Operation *> opsToErase;
   DominanceInfo *domInfo = nullptr;
+  MemEffectsCache memEffectsCache;
 };
 } // namespace
 
+void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
+                               Operation *existing, bool hasSSADominance) {
+  // If we find one then replace all uses of the current operation with the
+  // existing one and mark it for deletion. We can only replace an operand in
+  // an operation if it has not been visited yet.
+  if (hasSSADominance) {
+    // If the region has SSA dominance, then we are guaranteed to have not
+    // visited any use of the current operation.
+    op->replaceAllUsesWith(existing);
+    opsToErase.push_back(op);
+  } else {
+    // When the region does not have SSA dominance, we need to check if we
+    // have visited a use before replacing any use.
+    for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
+      std::get<0>(it).replaceUsesWithIf(
+          std::get<1>(it), [&](OpOperand &operand) {
+            return !knownValues.count(operand.getOwner());
+          });
+    }
+
+    // There may be some remaining uses of the operation.
+    if (op->use_empty())
+      opsToErase.push_back(op);
+  }
+
+  // If the existing operation has an unknown location and the current
+  // operation doesn't, then set the existing op's location to that of the
+  // current op.
+  if (existing->getLoc().isa<UnknownLoc>() && !op->getLoc().isa<UnknownLoc>())
+    existing->setLoc(op->getLoc());
+
+  ++numCSE;
+}
+
+bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
+  assert(fromOp->getBlock() == toOp->getBlock());
+  assert(
+      isa<MemoryEffectOpInterface>(fromOp) &&
+      cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
+      isa<MemoryEffectOpInterface>(toOp) &&
+      cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+  Operation *nextOp = fromOp->getNextNode();
+  auto result =
+      memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
+  if (result.second) {
+    auto memEffectsCachePair = result.first->second;
+    if (memEffectsCachePair.second == nullptr) {
+      // No MemoryEffects::Write has been detected until the cached operation.
+      // Continue looking from the cached operation to toOp.
+      nextOp = memEffectsCachePair.first;
+    } else {
+      // MemoryEffects::Write has been detected before so there is no need to
+      // check further.
+      return true;
+    }
+  }
+  while (nextOp && nextOp != toOp) {
+    auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
+    // TODO: Do we need to handle other effects generically?
+    // If the operation does not implement the MemoryEffectOpInterface we
+    // conservatively assumes it writes.
+    if ((nextOpMemEffects &&
+         nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
+        !nextOpMemEffects) {
+      result.first->second =
+          std::make_pair(nextOp, MemoryEffects::Write::get());
+      return true;
+    }
+    nextOp = nextOp->getNextNode();
+  }
+  result.first->second = std::make_pair(toOp, nullptr);
+  return false;
+}
+
 /// Attempt to eliminate a redundant operation.
 LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                      bool hasSSADominance) {
@@ -111,45 +201,34 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
   if (op->getNumRegions() != 0)
     return failure();
 
-  // TODO: We currently only eliminate non side-effecting
-  // operations.
-  if (!MemoryEffectOpInterface::hasNoEffect(op))
+  // Some simple use case of operation with memory side-effect are dealt with
+  // here. Operations with no side-effect are done after.
+  if (!MemoryEffectOpInterface::hasNoEffect(op)) {
+    auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
+    // TODO: Only basic use case for operations with MemoryEffects::Read can be
+    // eleminated now. More work needs to be done for more complicated patterns
+    // and other side-effects.
+    if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+      return failure();
+
+    // Look for an existing definition for the operation.
+    if (auto *existing = knownValues.lookup(op)) {
+      if (existing->getBlock() == op->getBlock() &&
+          !hasOtherSideEffectingOpInBetween(existing, op)) {
+        // The operation that can be deleted has been reach with no
+        // side-effecting operations in between the existing operation and
+        // this one so we can remove the duplicate.
+        replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
+        return success();
+      }
+    }
+    knownValues.insert(op, op);
     return failure();
+  }
 
   // Look for an existing definition for the operation.
   if (auto *existing = knownValues.lookup(op)) {
-
-    // If we find one then replace all uses of the current operation with the
-    // existing one and mark it for deletion. We can only replace an operand in
-    // an operation if it has not been visited yet.
-    if (hasSSADominance) {
-      // If the region has SSA dominance, then we are guaranteed to have not
-      // visited any use of the current operation.
-      op->replaceAllUsesWith(existing);
-      opsToErase.push_back(op);
-    } else {
-      // When the region does not have SSA dominance, we need to check if we
-      // have visited a use before replacing any use.
-      for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
-        std::get<0>(it).replaceUsesWithIf(
-            std::get<1>(it), [&](OpOperand &operand) {
-              return !knownValues.count(operand.getOwner());
-            });
-      }
-
-      // There may be some remaining uses of the operation.
-      if (op->use_empty())
-        opsToErase.push_back(op);
-    }
-
-    // If the existing operation has an unknown location and the current
-    // operation doesn't, then set the existing op's location to that of the
-    // current op.
-    if (existing->getLoc().isa<UnknownLoc>() &&
-        !op->getLoc().isa<UnknownLoc>()) {
-      existing->setLoc(op->getLoc());
-    }
-
+    replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
     ++numCSE;
     return success();
   }
@@ -184,6 +263,8 @@ void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
     for (auto &region : op.getRegions())
       simplifyRegion(knownValues, region);
   }
+  // Clear the MemoryEffects cache since its usage is by block only.
+  memEffectsCache.clear();
 }
 
 void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
index ad99faa..034474d 100644 (file)
@@ -32,8 +32,7 @@ toy.func @main() {
 // CHECK:         affine.for [[VAL_12:%.*]] = 0 to 3 {
 // CHECK:           affine.for [[VAL_13:%.*]] = 0 to 2 {
 // CHECK:             [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK:             [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK:             [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK:             [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
 // CHECK:             affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
 // CHECK:         toy.print [[VAL_6]] : memref<3x2xf64>
 // CHECK:         memref.dealloc [[VAL_8]] : memref<2x3xf64>
index ca056b4..51dedaf 100644 (file)
@@ -32,8 +32,7 @@ toy.func @main() {
 // CHECK:         affine.for [[VAL_12:%.*]] = 0 to 3 {
 // CHECK:           affine.for [[VAL_13:%.*]] = 0 to 2 {
 // CHECK:             [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK:             [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK:             [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK:             [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
 // CHECK:             affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
 // CHECK:         toy.print [[VAL_6]] : memref<3x2xf64>
 // CHECK:         memref.dealloc [[VAL_8]] : memref<2x3xf64>
index 60d466e..3cefd0e 100644 (file)
@@ -32,8 +32,7 @@ toy.func @main() {
 // CHECK:         affine.for [[VAL_12:%.*]] = 0 to 3 {
 // CHECK:           affine.for [[VAL_13:%.*]] = 0 to 2 {
 // CHECK:             [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK:             [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
-// CHECK:             [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
+// CHECK:             [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
 // CHECK:             affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
 // CHECK:         toy.print [[VAL_6]] : memref<3x2xf64>
 // CHECK:         memref.dealloc [[VAL_8]] : memref<2x3xf64>
index 982511f..189cdde 100644 (file)
@@ -265,3 +265,48 @@ func @use_before_def() {
   }
   return
 } 
+
+/// This test is checking that CSE is removing duplicated read op that follow
+/// other.
+// CHECK-LABEL: @remove_direct_duplicated_read_op
+func @remove_direct_duplicated_read_op() -> i32 {
+  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+  %1 = "test.op_with_memread"() : () -> (i32)
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32
+  %2 = arith.addi %0, %1 : i32
+  return %2 : i32
+}
+
+/// This test is checking that CSE is removing duplicated read op that follow
+/// other.
+// CHECK-LABEL: @remove_multiple_duplicated_read_op
+func @remove_multiple_duplicated_read_op() -> i64 {
+  // CHECK: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i64
+  %0 = "test.op_with_memread"() : () -> (i64)
+  %1 = "test.op_with_memread"() : () -> (i64)
+  // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64
+  %2 = arith.addi %0, %1 : i64
+  %3 = "test.op_with_memread"() : () -> (i64)
+  // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
+  %4 = arith.addi %2, %3 : i64
+  %5 = "test.op_with_memread"() : () -> (i64)
+  // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
+  %6 = arith.addi %4, %5 : i64
+  // CHECK-NEXT: return %{{.*}} : i64
+  return %6 : i64
+}
+
+/// This test is checking that CSE is not removing duplicated read op that
+/// have write op in between.
+// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
+func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
+  // CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+  "test.op_with_memwrite"() : () -> ()
+  // CHECK: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32
+  %1 = "test.op_with_memread"() : () -> (i32)
+  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32
+  %2 = arith.addi %0, %1 : i32
+  return %2 : i32
+}
index b157d5d..36e31d1 100644 (file)
@@ -2761,4 +2761,12 @@ def TestEffectsOpA : TEST_Op<"op_with_effects_a"> {
 def TestEffectsOpB : TEST_Op<"op_with_effects_b",
     [MemoryEffects<[MemWrite<TestResource>]>]>;
 
+def TestEffectsRead : TEST_Op<"op_with_memread",
+    [MemoryEffects<[MemRead]>]> {
+  let results = (outs AnyInteger);
+}
+
+def TestEffectsWrite : TEST_Op<"op_with_memwrite",
+    [MemoryEffects<[MemWrite]>]>;
+
 #endif // TEST_OPS