From 02da9643506dee4a82353e0f911513279634d846 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 7 Apr 2022 10:06:50 +0200 Subject: [PATCH] [mlir][CSE] Remove duplicated operations with MemRead side-effect This patch enhances the CSE pass to deal with simple cases of duplicated operations with MemoryEffects. It allows the CSE pass to remove safely duplicate operations with the MemoryEffects::Read that have no other side-effecting operations in between. Other MemoryEffects::Read operation are allowed. The use case is pretty simple so far so we can build on top of it to add more features. This patch is also meant to avoid a dedicated CSE pass in FIR and was brought together afetr discussion on https://reviews.llvm.org/D112711. It does not currently cover the full range of use cases described in https://reviews.llvm.org/D112711 but the idea is to gradually enhance the MLIR CSE pass to handle common use cases that can be used by other dialects. This patch takes advantage of the new CSE capabilities in Fir. Reviewed By: mehdi_amini, rriddle, schweitz Differential Revision: https://reviews.llvm.org/D122801 --- flang/include/flang/Optimizer/Dialect/FIROps.td | 4 +- flang/test/Fir/cse.fir | 57 +++++++++ mlir/lib/Transforms/CSE.cpp | 151 ++++++++++++++++++------ mlir/test/Examples/Toy/Ch5/affine-lowering.mlir | 3 +- mlir/test/Examples/Toy/Ch6/affine-lowering.mlir | 3 +- mlir/test/Examples/Toy/Ch7/affine-lowering.mlir | 3 +- mlir/test/Transforms/cse.mlir | 45 +++++++ mlir/test/lib/Dialect/Test/TestOps.td | 8 ++ 8 files changed, 231 insertions(+), 43 deletions(-) create mode 100644 flang/test/Fir/cse.fir diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 262d953..6eb0fdf 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -253,7 +253,7 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> { let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))"; } -def fir_LoadOp : fir_OneResultOp<"load"> { +def fir_LoadOp : fir_OneResultOp<"load", [MemoryEffects<[MemRead]>]> { let summary = "load a value from a memory reference"; let description = [{ Load a value from a memory reference into an ssa-value (virtual register). @@ -320,7 +320,7 @@ def fir_CharConvertOp : fir_Op<"char_convert", []> { let hasVerifier = 1; } -def fir_StoreOp : fir_Op<"store", []> { +def fir_StoreOp : fir_Op<"store", [MemoryEffects<[MemWrite]>]> { let summary = "store an SSA-value to a memory location"; let description = [{ diff --git a/flang/test/Fir/cse.fir b/flang/test/Fir/cse.fir new file mode 100644 index 0000000..148b689 --- /dev/null +++ b/flang/test/Fir/cse.fir @@ -0,0 +1,57 @@ +// RUN: fir-opt --cse -split-input-file %s | FileCheck %s + +// Check that the redundant fir.load is removed. +func @fun(%arg0: !fir.ref) -> i64 { + %0 = fir.load %arg0 : !fir.ref + %1 = fir.load %arg0 : !fir.ref + %2 = arith.addi %0, %1 : i64 + return %2 : i64 +} + +// CHECK-LABEL: func @fun +// CHECK-NEXT: %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref +// CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64 + +// ----- + +// CHECK-LABEL: func @fun( +// CHECK-SAME: %[[A:.*]]: !fir.ref +func @fun(%a : !fir.ref) -> i64 { + // CHECK: %[[LOAD:.*]] = fir.load %[[A]] : !fir.ref + %1 = fir.load %a : !fir.ref + %2 = fir.load %a : !fir.ref + // CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64 + %3 = arith.addi %1, %2 : i64 + %4 = fir.load %a : !fir.ref + // CHECK-NEXT: %{{.*}} = arith.addi + %5 = arith.addi %3, %4 : i64 + %6 = fir.load %a : !fir.ref + // CHECK-NEXT: %{{.*}} = arith.addi + %7 = arith.addi %5, %6 : i64 + %8 = fir.load %a : !fir.ref + // CHECK-NEXT: %{{.*}} = arith.addi + %9 = arith.addi %7, %8 : i64 + %10 = fir.load %a : !fir.ref + // CHECK-NEXT: %{{.*}} = arith.addi + %11 = arith.addi %10, %9 : i64 + %12 = fir.load %a : !fir.ref + // CHECK-NEXT: %{{.*}} = arith.addi + %13 = arith.addi %11, %12 : i64 + // CHECK-NEXT: return %{{.*}} : i64 + return %13 : i64 +} + +// ----- + +func @fun(%a : !fir.ref) -> i64 { + cf.br ^bb1 +^bb1: + %1 = fir.load %a : !fir.ref + %2 = fir.load %a : !fir.ref + %3 = arith.addi %1, %2 : i64 + cf.br ^bb2 +^bb2: + %4 = fir.load %a : !fir.ref + %5 = arith.subi %4, %4 : i64 + return %5 : i64 +} diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 0570c91..080e393 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -60,6 +60,14 @@ struct CSE : public CSEBase { using ScopedMapTy = llvm::ScopedHashTable; + /// Cache holding MemoryEffects information between two operations. The first + /// operation is stored has the key. The second operation is stored inside a + /// pair in the value. The pair also hold the MemoryEffects between those + /// two operations. If the MemoryEffects is nullptr then we assume there is + /// no operation with MemoryEffects::Write between the two operations. + using MemEffectsCache = + DenseMap>; + /// Represents a single entry in the depth first traversal of a CFG. struct CFGStackNode { CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node) @@ -85,12 +93,94 @@ struct CSE : public CSEBase { void runOnOperation() override; private: + void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, + Operation *existing, bool hasSSADominance); + + /// Check if there is side-effecting operations other than the given effect + /// between the two operations. + bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp); + /// Operations marked as dead and to be erased. std::vector opsToErase; DominanceInfo *domInfo = nullptr; + MemEffectsCache memEffectsCache; }; } // namespace +void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, + Operation *existing, bool hasSSADominance) { + // If we find one then replace all uses of the current operation with the + // existing one and mark it for deletion. We can only replace an operand in + // an operation if it has not been visited yet. + if (hasSSADominance) { + // If the region has SSA dominance, then we are guaranteed to have not + // visited any use of the current operation. + op->replaceAllUsesWith(existing); + opsToErase.push_back(op); + } else { + // When the region does not have SSA dominance, we need to check if we + // have visited a use before replacing any use. + for (auto it : llvm::zip(op->getResults(), existing->getResults())) { + std::get<0>(it).replaceUsesWithIf( + std::get<1>(it), [&](OpOperand &operand) { + return !knownValues.count(operand.getOwner()); + }); + } + + // There may be some remaining uses of the operation. + if (op->use_empty()) + opsToErase.push_back(op); + } + + // If the existing operation has an unknown location and the current + // operation doesn't, then set the existing op's location to that of the + // current op. + if (existing->getLoc().isa() && !op->getLoc().isa()) + existing->setLoc(op->getLoc()); + + ++numCSE; +} + +bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) { + assert(fromOp->getBlock() == toOp->getBlock()); + assert( + isa(fromOp) && + cast(fromOp).hasEffect() && + isa(toOp) && + cast(toOp).hasEffect()); + Operation *nextOp = fromOp->getNextNode(); + auto result = + memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr)); + if (result.second) { + auto memEffectsCachePair = result.first->second; + if (memEffectsCachePair.second == nullptr) { + // No MemoryEffects::Write has been detected until the cached operation. + // Continue looking from the cached operation to toOp. + nextOp = memEffectsCachePair.first; + } else { + // MemoryEffects::Write has been detected before so there is no need to + // check further. + return true; + } + } + while (nextOp && nextOp != toOp) { + auto nextOpMemEffects = dyn_cast(nextOp); + // TODO: Do we need to handle other effects generically? + // If the operation does not implement the MemoryEffectOpInterface we + // conservatively assumes it writes. + if ((nextOpMemEffects && + nextOpMemEffects.hasEffect()) || + !nextOpMemEffects) { + result.first->second = + std::make_pair(nextOp, MemoryEffects::Write::get()); + return true; + } + nextOp = nextOp->getNextNode(); + } + result.first->second = std::make_pair(toOp, nullptr); + return false; +} + /// Attempt to eliminate a redundant operation. LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op, bool hasSSADominance) { @@ -111,45 +201,34 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op, if (op->getNumRegions() != 0) return failure(); - // TODO: We currently only eliminate non side-effecting - // operations. - if (!MemoryEffectOpInterface::hasNoEffect(op)) + // Some simple use case of operation with memory side-effect are dealt with + // here. Operations with no side-effect are done after. + if (!MemoryEffectOpInterface::hasNoEffect(op)) { + auto memEffects = dyn_cast(op); + // TODO: Only basic use case for operations with MemoryEffects::Read can be + // eleminated now. More work needs to be done for more complicated patterns + // and other side-effects. + if (!memEffects || !memEffects.onlyHasEffect()) + return failure(); + + // Look for an existing definition for the operation. + if (auto *existing = knownValues.lookup(op)) { + if (existing->getBlock() == op->getBlock() && + !hasOtherSideEffectingOpInBetween(existing, op)) { + // The operation that can be deleted has been reach with no + // side-effecting operations in between the existing operation and + // this one so we can remove the duplicate. + replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); + return success(); + } + } + knownValues.insert(op, op); return failure(); + } // Look for an existing definition for the operation. if (auto *existing = knownValues.lookup(op)) { - - // If we find one then replace all uses of the current operation with the - // existing one and mark it for deletion. We can only replace an operand in - // an operation if it has not been visited yet. - if (hasSSADominance) { - // If the region has SSA dominance, then we are guaranteed to have not - // visited any use of the current operation. - op->replaceAllUsesWith(existing); - opsToErase.push_back(op); - } else { - // When the region does not have SSA dominance, we need to check if we - // have visited a use before replacing any use. - for (auto it : llvm::zip(op->getResults(), existing->getResults())) { - std::get<0>(it).replaceUsesWithIf( - std::get<1>(it), [&](OpOperand &operand) { - return !knownValues.count(operand.getOwner()); - }); - } - - // There may be some remaining uses of the operation. - if (op->use_empty()) - opsToErase.push_back(op); - } - - // If the existing operation has an unknown location and the current - // operation doesn't, then set the existing op's location to that of the - // current op. - if (existing->getLoc().isa() && - !op->getLoc().isa()) { - existing->setLoc(op->getLoc()); - } - + replaceUsesAndDelete(knownValues, op, existing, hasSSADominance); ++numCSE; return success(); } @@ -184,6 +263,8 @@ void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb, for (auto ®ion : op.getRegions()) simplifyRegion(knownValues, region); } + // Clear the MemoryEffects cache since its usage is by block only. + memEffectsCache.clear(); } void CSE::simplifyRegion(ScopedMapTy &knownValues, Region ®ion) { diff --git a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir index ad99faa..034474d 100644 --- a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir +++ b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir @@ -32,8 +32,7 @@ toy.func @main() { // CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { // CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { // CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> -// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> -// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64 +// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64 // CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> // CHECK: toy.print [[VAL_6]] : memref<3x2xf64> // CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64> diff --git a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir index ca056b4..51dedaf 100644 --- a/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir +++ b/mlir/test/Examples/Toy/Ch6/affine-lowering.mlir @@ -32,8 +32,7 @@ toy.func @main() { // CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { // CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { // CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> -// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> -// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64 +// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64 // CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> // CHECK: toy.print [[VAL_6]] : memref<3x2xf64> // CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64> diff --git a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir index 60d466e..3cefd0e 100644 --- a/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir +++ b/mlir/test/Examples/Toy/Ch7/affine-lowering.mlir @@ -32,8 +32,7 @@ toy.func @main() { // CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 { // CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 { // CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> -// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> -// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64 +// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64 // CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64> // CHECK: toy.print [[VAL_6]] : memref<3x2xf64> // CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64> diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index 982511f..189cdde 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -265,3 +265,48 @@ func @use_before_def() { } return } + +/// This test is checking that CSE is removing duplicated read op that follow +/// other. +// CHECK-LABEL: @remove_direct_duplicated_read_op +func @remove_direct_duplicated_read_op() -> i32 { + // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32 + %0 = "test.op_with_memread"() : () -> (i32) + %1 = "test.op_with_memread"() : () -> (i32) + // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32 + %2 = arith.addi %0, %1 : i32 + return %2 : i32 +} + +/// This test is checking that CSE is removing duplicated read op that follow +/// other. +// CHECK-LABEL: @remove_multiple_duplicated_read_op +func @remove_multiple_duplicated_read_op() -> i64 { + // CHECK: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i64 + %0 = "test.op_with_memread"() : () -> (i64) + %1 = "test.op_with_memread"() : () -> (i64) + // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64 + %2 = arith.addi %0, %1 : i64 + %3 = "test.op_with_memread"() : () -> (i64) + // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 + %4 = arith.addi %2, %3 : i64 + %5 = "test.op_with_memread"() : () -> (i64) + // CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64 + %6 = arith.addi %4, %5 : i64 + // CHECK-NEXT: return %{{.*}} : i64 + return %6 : i64 +} + +/// This test is checking that CSE is not removing duplicated read op that +/// have write op in between. +// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting +func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 { + // CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32 + %0 = "test.op_with_memread"() : () -> (i32) + "test.op_with_memwrite"() : () -> () + // CHECK: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32 + %1 = "test.op_with_memread"() : () -> (i32) + // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32 + %2 = arith.addi %0, %1 : i32 + return %2 : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index b157d5d..36e31d1 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2761,4 +2761,12 @@ def TestEffectsOpA : TEST_Op<"op_with_effects_a"> { def TestEffectsOpB : TEST_Op<"op_with_effects_b", [MemoryEffects<[MemWrite]>]>; +def TestEffectsRead : TEST_Op<"op_with_memread", + [MemoryEffects<[MemRead]>]> { + let results = (outs AnyInteger); +} + +def TestEffectsWrite : TEST_Op<"op_with_memwrite", + [MemoryEffects<[MemWrite]>]>; + #endif // TEST_OPS -- 2.7.4