From 04b5274ede3ebc1de98c47e34cb762bae474696b Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Sun, 5 Apr 2020 08:10:33 +0530 Subject: [PATCH] [MLIR] Introduce applyOpPatternsAndFold for op local rewrites Introduce mlir::applyOpPatternsAndFold which applies patterns as well as any folding only on a specified op (in contrast to applyPatternsAndFoldGreedily which applies patterns only on the regions of an op isolated from above). The caller is made aware of the op being folded away or erased. Depends on D77485. Differential Revision: https://reviews.llvm.org/D77487 --- mlir/include/mlir/IR/PatternMatch.h | 9 ++ .../Affine/Transforms/AffineDataCopyGeneration.cpp | 31 +++--- .../Affine/Transforms/SimplifyAffineStructures.cpp | 21 +++- .../Utils/GreedyPatternRewriteDriver.cpp | 115 ++++++++++++++++++++- mlir/lib/Transforms/Utils/LoopUtils.cpp | 20 +++- mlir/test/Dialect/Affine/affine-data-copy.mlir | 10 +- .../Dialect/Affine/simplify-affine-structures.mlir | 84 +++++++++------ .../test/lib/Dialect/Affine/TestAffineDataCopy.cpp | 35 +++++-- 8 files changed, 258 insertions(+), 67 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 4679d98..6dbc5b9 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -455,6 +455,15 @@ bool applyPatternsAndFoldGreedily(Operation *op, /// Rewrite the given regions, which must be isolated from above. bool applyPatternsAndFoldGreedily(MutableArrayRef regions, const OwningRewritePatternList &patterns); + +/// Applies the specified patterns on `op` alone while also trying to fold it, +/// by selecting the highest-benefit patterns in a greedy manner. Returns true +/// if no more patterns can be matched. `erased` is set to true if `op` was +/// folded away or erased as a result of becoming dead. Note: This does not +/// apply any patterns recursively to the regions of `op`. 
+bool applyOpPatternsAndFold(Operation *op, + const OwningRewritePatternList &patterns, + bool *erased = nullptr); } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index c861b21..78128ff 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -211,20 +211,25 @@ void AffineDataCopyGeneration::runOnFunction() { for (auto &block : f) runOnBlock(&block, copyNests); - // Promote any single iteration loops in the copy nests. + // Promote any single iteration loops in the copy nests and collect + // load/stores to simplify. + SmallVector copyOps; for (auto nest : copyNests) - nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); + // With a post order walk, the erasure of loops does not affect + // continuation of the walk or the collection of load/store ops. + nest->walk([&](Operation *op) { + if (auto forOp = dyn_cast(op)) + promoteIfSingleIteration(forOp); + else if (isa(op) || isa(op)) + copyOps.push_back(op); + }); // Promoting single iteration loops could lead to simplification of - // load's/store's. We will run canonicalization patterns on load/stores. - // TODO: this whole function load/store canonicalization should be replaced by - // canonicalization that is limited to only the load/store ops - // introduced/touched by this pass (those inside 'copyNests'). This would be - // possible once the necessary support is available in the pattern rewriter. - if (!copyNests.empty()) { - OwningRewritePatternList patterns; - AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); - AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(f, std::move(patterns)); - } + // contained load's/store's, and the latter could anyway also be + // canonicalized. 
+ OwningRewritePatternList patterns; + AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); + AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); + for (auto op : copyOps) + applyOpPatternsAndFold(op, std::move(patterns)); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index 0df4ea0..fada39a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -6,14 +6,16 @@ // //===----------------------------------------------------------------------===// // -// This file implements a pass to simplify affine structures. +// This file implements a pass to simplify affine structures in operations. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/Utils.h" #define DEBUG_TYPE "simplify-affine-structure" @@ -77,13 +79,22 @@ mlir::createSimplifyAffineStructuresPass() { void SimplifyAffineStructures::runOnFunction() { auto func = getFunction(); simplifiedAttributes.clear(); - func.walk([&](Operation *opInst) { - for (auto attr : opInst->getAttrs()) { + OwningRewritePatternList patterns; + AffineForOp::getCanonicalizationPatterns(patterns, func.getContext()); + AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext()); + AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); + func.walk([&](Operation *op) { + for (auto attr : op->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) - simplifyAndUpdateAttribute(opInst, attr.first, mapAttr); + simplifyAndUpdateAttribute(op, attr.first, mapAttr); else if (auto setAttr = attr.second.dyn_cast()) - 
simplifyAndUpdateAttribute(opInst, attr.first, setAttr); + simplifyAndUpdateAttribute(op, attr.first, setAttr); } + + // The simplification of the attribute will likely simplify the op. Try to + // fold / apply canonicalization patterns when we have affine dialect ops. + if (isa(op) || isa(op) || isa(op)) + applyOpPatternsAndFold(op, patterns); }); // Turn memrefs' non-identity layouts maps into ones with identity. Collect diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 53c8e9f..256c134 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -26,6 +26,10 @@ using namespace mlir; /// The max number of iterations scanning for pattern match. static unsigned maxPatternMatchIterations = 10; +//===----------------------------------------------------------------------===// +// GreedyPatternRewriteDriver +//===----------------------------------------------------------------------===// + namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly /// applies the locally optimal patterns in a roughly "bottom up" way. @@ -37,8 +41,6 @@ public: worklist.reserve(64); } - /// Perform the rewrites while folding and erasing any dead ops. Return true - /// if the rewrite converges in `maxIterations`. bool simplify(MutableArrayRef regions, int maxIterations); void addToWorklist(Operation *op) { @@ -248,3 +250,112 @@ bool mlir::applyPatternsAndFoldGreedily( }); return converged; } + +//===----------------------------------------------------------------------===// +// OpPatternRewriteDriver +//===----------------------------------------------------------------------===// + +namespace { +/// This is a simple driver for the PatternMatcher to apply patterns and perform +/// folding on a single op. It repeatedly applies locally optimal patterns. 
+class OpPatternRewriteDriver : public PatternRewriter { +public: + explicit OpPatternRewriteDriver(MLIRContext *ctx, + const OwningRewritePatternList &patterns) + : PatternRewriter(ctx), matcher(patterns), folder(ctx) {} + + bool simplifyLocally(Operation *op, int maxIterations, bool &erased); + + /// No additional action needed other than inserting the op. + Operation *insert(Operation *op) override { return OpBuilder::insert(op); } + + // These are hooks implemented for PatternRewriter. +protected: + /// If an operation is about to be removed, mark it so that we can let clients + /// know. + void notifyOperationRemoved(Operation *op) override { + opErasedViaPatternRewrites = true; + } + + // When a root is going to be replaced, its removal will be notified as well. + // So there is nothing to do here. + void notifyRootReplaced(Operation *op) override {} + +private: + /// The low-level pattern matcher. + RewritePatternMatcher matcher; + + /// Non-pattern based folder for operations. + OperationFolder folder; + + /// Set to true if the operation has been erased via pattern rewrites. + bool opErasedViaPatternRewrites = false; +}; + +} // anonymous namespace + +/// Performs the rewrites and folding only on `op`. The simplification converges +/// if the op is erased as a result of being folded, replaced, or dead, or no +/// more changes happen in an iteration. Returns true if the rewrite converges +/// in `maxIterations`. `erased` is set to true if `op` gets erased. +bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations, + bool &erased) { + bool changed = false; + erased = false; + opErasedViaPatternRewrites = false; + int i = 0; + // Iterate until convergence or until maxIterations. Deletion of the op as + // a result of being dead or folded is convergence. + do { + // If the operation is trivially dead - remove it. + if (isOpTriviallyDead(op)) { + op->erase(); + erased = true; + return true; + } + + // Try to fold this op. 
+ bool inPlaceUpdate; + if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr, + /*preReplaceAction=*/nullptr, + &inPlaceUpdate))) { + changed = true; + if (!inPlaceUpdate) { + erased = true; + return true; + } + } + + // Make sure that any new operations are inserted at this point. + setInsertionPoint(op); + + // Try to match one of the patterns. The rewriter is automatically + // notified of any necessary changes, so there is nothing else to do here. + changed |= matcher.matchAndRewrite(op, *this); + if ((erased = opErasedViaPatternRewrites)) + return true; + } while (changed && ++i < maxIterations); + + // Whether the rewrite converges, i.e. wasn't changed in the last iteration. + return !changed; +} + +/// Rewrites only `op` using the supplied canonicalization patterns and +/// folding. `erased` is set to true if the op is erased as a result of being +/// folded, replaced, or dead. +bool mlir::applyOpPatternsAndFold(Operation *op, + const OwningRewritePatternList &patterns, + bool *erased) { + // Start the pattern driver. 
+ OpPatternRewriteDriver driver(op->getContext(), patterns); + bool opErased; + bool converged = + driver.simplifyLocally(op, maxPatternMatchIterations, opErased); + if (erased) + *erased = opErased; + LLVM_DEBUG(if (!converged) { + llvm::dbgs() << "The pattern rewrite doesn't converge after scanning " + << maxPatternMatchIterations << " times"; + }); + return converged; +} diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 72f889e..9fe9643 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/RegionUtils.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" @@ -312,9 +313,19 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp, opGroupQueue, /*offset=*/0, forOp, b); lbShift = d * step; } - if (!prologue && res) - prologue = res; - epilogue = res; + + if (res) { + // Simplify/canonicalize the affine.for. + OwningRewritePatternList patterns; + AffineForOp::getCanonicalizationPatterns(patterns, res.getContext()); + bool erased; + applyOpPatternsAndFold(res, std::move(patterns), &erased); + + if (!erased && !prologue) + prologue = res; + if (!erased) + epilogue = res; + } } else { // Start of first interval. lbShift = d * step; @@ -694,7 +705,8 @@ bool mlir::isValidLoopInterchangePermutation(ArrayRef loops, } /// Return true if `loops` is a perfect nest. 
-static bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef loops) { +static bool LLVM_ATTRIBUTE_UNUSED +isPerfectlyNested(ArrayRef loops) { auto outerLoop = loops.front(); for (auto loop : loops.drop_front()) { auto parentForOp = dyn_cast(loop.getParentOp()); diff --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir index 52c60d7..97d64a6 100644 --- a/mlir/test/Dialect/Affine/affine-data-copy.mlir +++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir @@ -216,7 +216,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> { return %A : memref<4096xf32> } // CHECK: affine.for %[[IV1:.*]] = 0 to 4096 step 100 -// CHECK-NEXT: %[[BUF:.*]] = alloc() : memref<100xf32> +// CHECK: %[[BUF:.*]] = alloc() : memref<100xf32> // CHECK-NEXT: affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) { // CHECK-NEXT: affine.load %{{.*}}[%[[IV2]]] : memref<4096xf32> // CHECK-NEXT: affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32> @@ -226,7 +226,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> { // CHECK-NEXT: mulf // CHECK-NEXT: affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32> // CHECK-NEXT: } -// CHECK-NEXT: affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) { +// CHECK: affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) { // CHECK-NEXT: affine.load %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32> // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[IV2]]] : memref<4096xf32> // CHECK-NEXT: } @@ -239,8 +239,8 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> { // with multi-level tiling when the tile sizes used don't divide loop trip // counts. 
-#lb = affine_map<(d0, d1) -> (d0 * 512, d1 * 6)> -#ub = affine_map<(d0, d1) -> (d0 * 512 + 512, d1 * 6 + 6)> +#lb = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)> +#ub = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)> // CHECK-DAG: #[[LB:.*]] = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)> // CHECK-DAG: #[[UB:.*]] = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)> @@ -250,7 +250,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> { // CHECK-SAME: [[j:arg[0-9]+]] func @max_lower_bound(%M: memref<2048x516xf64>, %i : index, %j : index) { affine.for %ii = 0 to 2048 { - affine.for %jj = max #lb(%i, %j) to min #ub(%i, %j) { + affine.for %jj = max #lb()[%i, %j] to min #ub()[%i, %j] { affine.load %M[%ii, %jj] : memref<2048x516xf64> } } diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir index 832d723..49fa339 100644 --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -1,19 +1,19 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -simplify-affine-structures | FileCheck %s -// CHECK-DAG: #[[SET_EMPTY_2D:.*]] = affine_set<(d0, d1) : (1 == 0)> +// CHECK-DAG: #[[SET_EMPTY:.*]] = affine_set<() : (1 == 0)> // CHECK-DAG: #[[SET_2D:.*]] = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)> -// CHECK-DAG: #[[SET_EMPTY_2D_2S:.*]] = affine_set<(d0, d1)[s0, s1] : (1 == 0)> -// CHECK-DAG: #[[SET_2D_2S:.*]] = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)> -// CHECK-DAG: #[[SET_EMPTY_1D:.*]] = affine_set<(d0) : (1 == 0)> -// CHECK-DAG: #[[SET_EMPTY_1D_2S:.*]] = affine_set<(d0)[s0, s1] : (1 == 0)> -// CHECK-DAG: #[[SET_EMPTY_3D:.*]] = affine_set<(d0, d1, d2) : (1 == 0)> +// CHECK-DAG: #[[SET_7_11:.*]] = affine_set<(d0, d1) : (d0 * 7 + d1 
* 5 + 88 == 0, d0 * 5 - d1 * 11 + 60 == 0, d0 * 11 + d1 * 7 - 24 == 0, d0 * 7 + d1 * 5 + 88 == 0)> + +// An external function that we will use in bodies to avoid DCE. +func @external() -> () // CHECK-LABEL: func @test_gaussian_elimination_empty_set0() { func @test_gaussian_elimination_empty_set0() { affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: [[SET_EMPTY_2D]](%arg0, %arg1) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1) : (2 == 0)>(%arg0, %arg1) { + call @external() : () -> () } } } @@ -24,8 +24,9 @@ func @test_gaussian_elimination_empty_set0() { func @test_gaussian_elimination_empty_set1() { affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: [[SET_EMPTY_2D]](%arg0, %arg1) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1) : (1 >= 0, -1 >= 0)> (%arg0, %arg1) { + call @external() : () -> () } } } @@ -38,6 +39,7 @@ func @test_gaussian_elimination_non_empty_set2() { affine.for %arg1 = 1 to 100 { // CHECK: #[[SET_2D]](%arg0, %arg1) affine.if affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)>(%arg0, %arg1) { + call @external() : () -> () } } } @@ -50,8 +52,9 @@ func @test_gaussian_elimination_empty_set3() { %c11 = constant 11 : index affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: #[[SET_EMPTY_2D_2S]](%arg0, %arg1)[%c7, %c11] + // CHECK: #[[SET_EMPTY]]() affine.if affine_set<(d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)>(%arg0, %arg1)[%c7, %c11] { + call @external() : () -> () } } } @@ -70,8 +73,9 @@ func @test_gaussian_elimination_non_empty_set4() { %c11 = constant 11 : index affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: #[[SET_2D_2S]](%arg0, %arg1)[%c7, %c11] + // CHECK: #[[SET_7_11]](%arg0, %arg1) affine.if #set_2d_non_empty(%arg0, %arg1)[%c7, %c11] { + call @external() : () -> () } } } @@ -79,7 +83,6 @@ func @test_gaussian_elimination_non_empty_set4() { } // Add invalid 
constraints to previous non-empty set to make it empty. -// Set for test case: test_gaussian_elimination_empty_set5 #set_2d_empty = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, @@ -92,8 +95,9 @@ func @test_gaussian_elimination_empty_set5() { %c11 = constant 11 : index affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: #[[SET_EMPTY_2D_2S]](%arg0, %arg1)[%c7, %c11] + // CHECK: #[[SET_EMPTY]]() affine.if #set_2d_empty(%arg0, %arg1)[%c7, %c11] { + call @external() : () -> () } } } @@ -147,6 +151,7 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i affine.for %arg4 = 1 to 10 { affine.for %arg5 = 1 to 100 { affine.if #set_fuzz_virus(%arg4, %arg5, %arg0, %arg1, %arg2, %arg3) { + call @external() : () -> () } } } @@ -157,33 +162,33 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i func @test_empty_set(%N : index) { affine.for %i = 0 to 10 { affine.for %j = 0 to 10 { - // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)>(%i, %j) { "foo"() : () -> () } - // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) { "bar"() : () -> () } - // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) { "foo"() : () -> () } - // CHECK: affine.if #[[SET_EMPTY_1D_2S]](%arg1)[%arg0, %arg0] + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, -s0 >= 0)>(%i)[%N, %N] { "bar"() : () -> () } - // CHECK: affine.if #[[SET_EMPTY_3D]](%arg1, %arg2, %arg0) + // CHECK: affine.if #[[SET_EMPTY]]() // The set below implies d0 = d1; so d1 >= d0, but d0 >= d1 + 1. 
affine.if affine_set<(d0, d1, d2) : (d0 - d1 == 0, d2 - d0 >= 0, d0 - d1 - 1 >= 0)>(%i, %j, %N) { "foo"() : () -> () } - // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2) + // CHECK: affine.if #[[SET_EMPTY]]() // The set below has rational solutions but no integer solutions; GCD test catches it. affine.if affine_set<(d0, d1) : (d0*2 -d1*2 - 1 == 0, d0 >= 0, -d0 + 100 >= 0, d1 >= 0, -d1 + 100 >= 0)>(%i, %j) { "foo"() : () -> () } - // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1) : (d1 == 0, d0 - 1 >= 0, - d0 - 1 >= 0)>(%i, %j) { "foo"() : () -> () } @@ -193,12 +198,12 @@ func @test_empty_set(%N : index) { affine.for %k = 0 to 10 { affine.for %l = 0 to 10 { // Empty because no multiple of 8 lies between 4 and 7. - // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)>(%k) { "foo"() : () -> () } // Same as above but with equalities and inequalities. - // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1) : (d0 - 4*d1 == 0, 4*d1 - 5 >= 0, -4*d1 + 7 >= 0)>(%k, %l) { "foo"() : () -> () } @@ -206,12 +211,12 @@ func @test_empty_set(%N : index) { // 8*d1 here is a multiple of 4, and so can't lie between 9 and 11. GCD // tightening will tighten constraints to 4*d0 + 8*d1 >= 12 and 4*d0 + // 8*d1 <= 8; hence infeasible. - // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1) : (4*d0 + 8*d1 - 9 >= 0, -4*d0 - 8*d1 + 11 >= 0)>(%k, %l) { "foo"() : () -> () } // Same as above but with equalities added into the mix. 
- // CHECK: affine.if #[[SET_EMPTY_3D]](%arg1, %arg1, %arg2) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0, d1, d2) : (d0 - 4*d2 == 0, d0 + 8*d1 - 9 >= 0, -d0 - 8*d1 + 11 >= 0)>(%k, %k, %l) { "foo"() : () -> () } @@ -219,7 +224,7 @@ func @test_empty_set(%N : index) { } affine.for %m = 0 to 10 { - // CHECK: affine.if #[[SET_EMPTY_1D]](%arg{{[0-9]+}}) + // CHECK: affine.if #[[SET_EMPTY]]() affine.if affine_set<(d0) : (d0 mod 2 - 3 == 0)> (%m) { "foo"() : () -> () } @@ -230,20 +235,39 @@ func @test_empty_set(%N : index) { // ----- -// CHECK-DAG: #[[SET_2D:.*]] = affine_set<(d0, d1) : (d0 >= 0, -d0 + 50 >= 0) -// CHECK-DAG: #[[SET_EMPTY:.*]] = affine_set<(d0, d1) : (1 == 0) -// CHECK-DAG: #[[SET_UNIV:.*]] = affine_set<(d0, d1) : (0 == 0) +// An external function that we will use in bodies to avoid DCE. +func @external() -> () + +// CHECK-DAG: #[[SET:.*]] = affine_set<()[s0] : (s0 >= 0, -s0 + 50 >= 0) +// CHECK-DAG: #[[EMPTY_SET:.*]] = affine_set<() : (1 == 0) +// CHECK-DAG: #[[UNIV_SET:.*]] = affine_set<() : (0 == 0) // CHECK-LABEL: func @simplify_set func @simplify_set(%a : index, %b : index) { - // CHECK: affine.if #[[SET_2D]] + // CHECK: affine.if #[[SET]] affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) { + call @external() : () -> () } - // CHECK: affine.if #[[SET_EMPTY]] + // CHECK: affine.if #[[EMPTY_SET]] affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) { + call @external() : () -> () } - // CHECK: affine.if #[[SET_UNIV]] + // CHECK: affine.if #[[UNIV_SET]] affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) { + call @external() : () -> () } return } + +// ----- + +// CHECK-DAG: -> (s0 * 2 + 1) + +// Test "op local" simplification on affine.apply. DCE on addi will not happen. 
+func @affine.apply(%N : index) { + %v = affine.apply affine_map<(d0, d1) -> (d0 + d1 + 1)>(%N, %N) + addi %v, %v : index + // CHECK: affine.apply #map{{.*}}()[%arg0] + // CHECK-NEXT: addi + return +} diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index 6c8b546..b6df06e 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -88,16 +88,35 @@ void TestAffineDataCopy::runOnFunction() { generateCopyForMemRegion(region, loopNest, copyOptions, result); } - // Promote any single iteration loops in the copy nests. + // Promote any single iteration loops in the copy nests and simplify + // load/stores. + SmallVector copyOps; for (auto nest : copyNests) - nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); }); - - // Promoting single iteration loops could lead to simplification - // of load's/store's. We will run the canonicalization patterns again. + // With a post order walk, the erasure of loops does not affect + // continuation of the walk or the collection of load/store ops. + nest->walk([&](Operation *op) { + if (auto forOp = dyn_cast(op)) + promoteIfSingleIteration(forOp); + else if (auto loadOp = dyn_cast(op)) + copyOps.push_back(loadOp); + else if (auto storeOp = dyn_cast(op)) + copyOps.push_back(storeOp); + }); + + // Promoting single iteration loops could lead to simplification of + // generated load's/store's, and the latter could anyway also be + // canonicalized. 
OwningRewritePatternList patterns; - AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); - AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + for (auto op : copyOps) { + patterns.clear(); + if (isa(op)) { + AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); + } else { + assert(isa(op) && "expected affine store op"); + AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); + } + applyOpPatternsAndFold(op, std::move(patterns)); + } } namespace mlir { -- 2.7.4