From c45c96250b422fc59cd3b3454fddf737512cf838 Mon Sep 17 00:00:00 2001 From: Vinayaka Bandishti Date: Mon, 5 Jun 2023 10:47:42 +0530 Subject: [PATCH] [Affine-fusion] Fix a bug in mod detection Fix a bug in detecting unknown ids as mods of known ids that was preventing certain fusions. While at this, fix the function signature of `detectAsMod` function to have output as the last argument. Reviewed By: bondhugula Differential Revision: https://reviews.llvm.org/D152055 --- mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 24 ++++++++++++---- mlir/test/Transforms/loop-fusion-4.mlir | 36 ++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 348ffbf..2dbb2e6 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -221,12 +221,19 @@ LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) { // // `var_q = var_n floordiv divisor`. // +// First 'num' dimensional variables starting at 'offset' are +// derived/to-be-derived in terms of the remaining variables. The remaining +// variables are assigned trivial affine expressions in `memo`. For example, +// memo is initialized as follows for a `cst` with 5 dims, when offset=2, num=2: +// memo ==> d0 d1 . . d2 ... +// cst ==> c0 c1 c2 c3 c4 ... +// // Returns true if the above mod or floordiv are detected, updating 'memo' with // these new expressions. Returns false otherwise. 
static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos, - int64_t lbConst, int64_t ubConst, - SmallVectorImpl<AffineExpr> &memo, - MLIRContext *context) { + unsigned offset, unsigned num, int64_t lbConst, + int64_t ubConst, MLIRContext *context, + SmallVectorImpl<AffineExpr> &memo) { assert(pos < cst.getNumVars() && "invalid position"); // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can @@ -308,7 +315,13 @@ static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos, // Express `var_r` as `var_n % divisor` and store the expression in `memo`. if (quotientCount >= 1) { - auto ub = cst.getConstantBound64(BoundType::UB, dimExpr.getPosition()); + // Find the column corresponding to `dimExpr`. `num` columns starting at + // `offset` correspond to previously unknown variables. The column + // corresponding to the trivially known `dimExpr` can be on either side + // of these. + unsigned dimExprPos = dimExpr.getPosition(); + unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num; + auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol); // If `var_n` has an upperbound that is less than the divisor, mod can be // eliminated altogether. if (ub && *ub < divisor) @@ -499,7 +512,8 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num, // Detect a variable as modulo of another variable w.r.t a // constant. 
- if (detectAsMod(*this, pos, *lbConst, *ubConst, memo, context)) { + if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context, + memo)) { changed = true; continue; } diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir index 2d4a27c..3fc31ad 100644 --- a/mlir/test/Transforms/loop-fusion-4.mlir +++ b/mlir/test/Transforms/loop-fusion-4.mlir @@ -190,3 +190,39 @@ func.func @fusion_for_multiple_blocks() { // PRODUCER-CONSUMER-NEXT: } return } + +// ----- + +// PRODUCER-CONSUMER-LABEL: @fuse_higher_dim_nest_into_lower_dim_nest +func.func @fuse_higher_dim_nest_into_lower_dim_nest() { + %A = memref.alloc() : memref<8x12x128x64xf32> + %B = memref.alloc() : memref<8x128x12x64xf32> + affine.for %arg205 = 0 to 8 { + affine.for %arg206 = 0 to 128 { + affine.for %arg207 = 0 to 12 { + affine.for %arg208 = 0 to 64 { + %a = affine.load %A[%arg205, %arg207, %arg206, %arg208] : memref<8x12x128x64xf32> + affine.store %a, %B[%arg205, %arg206, %arg207, %arg208] : memref<8x128x12x64xf32> + } + } + } + } + %C = memref.alloc() : memref<8x128x768xf16> + affine.for %arg205 = 0 to 8 { + affine.for %arg206 = 0 to 128 { + affine.for %arg207 = 0 to 768 { + %b = affine.load %B[%arg205, %arg206, %arg207 floordiv 64, %arg207 mod 64] : memref<8x128x12x64xf32> + %c = arith.truncf %b : f32 to f16 + affine.store %c, %C[%arg205, %arg206, %arg207] : memref<8x128x768xf16> + } + } + } + + // Check that fusion happens into the innermost loop of the consumer. + // PRODUCER-CONSUMER: affine.for + // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 128 + // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 768 + // PRODUCER-CONSUMER-NOT: affine.for + // PRODUCER-CONSUMER: return + return +} -- 2.7.4