[Affine-fusion] Fix a bug in mod detection
authorVinayaka Bandishti <vinayaka@polymagelabs.com>
Mon, 5 Jun 2023 05:17:42 +0000 (10:47 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Mon, 5 Jun 2023 05:17:48 +0000 (10:47 +0530)
Fix a bug in detecting unknown ids as mods of known ids that was
preventing certain fusions.

While at this, fix the function signature of `detectAsMod` function to
have output as the last argument.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D152055

mlir/lib/Analysis/FlatLinearValueConstraints.cpp
mlir/test/Transforms/loop-fusion-4.mlir

index 348ffbf..2dbb2e6 100644 (file)
@@ -221,12 +221,19 @@ LogicalResult FlatLinearConstraints::composeMatchingMap(AffineMap other) {
 //
 // `var_q = var_n floordiv divisor`.
 //
+// First 'num' dimensional variables starting at 'offset' are
+// derived/to-be-derived in terms of the remaining variables. The remaining
+// variables are assigned trivial affine expressions in `memo`. For example,
+// memo is initilized as follows for a `cst` with 5 dims, when offset=2, num=2:
+// memo ==>  d0  d1  .   .   d2 ...
+// cst  ==>  c0  c1  c2  c3  c4 ...
+//
 // Returns true if the above mod or floordiv are detected, updating 'memo' with
 // these new expressions. Returns false otherwise.
 static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
-                        int64_t lbConst, int64_t ubConst,
-                        SmallVectorImpl<AffineExpr> &memo,
-                        MLIRContext *context) {
+                        unsigned offset, unsigned num, int64_t lbConst,
+                        int64_t ubConst, MLIRContext *context,
+                        SmallVectorImpl<AffineExpr> &memo) {
   assert(pos < cst.getNumVars() && "invalid position");
 
   // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
@@ -308,7 +315,13 @@ static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
 
     // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
     if (quotientCount >= 1) {
-      auto ub = cst.getConstantBound64(BoundType::UB, dimExpr.getPosition());
+      // Find the column corresponding to `dimExpr`. `num` columns starting at
+      // `offset` correspond to previously unknown variables. The column
+      // corresponding to the trivially known `dimExpr` can be on either side
+      // of these.
+      unsigned dimExprPos = dimExpr.getPosition();
+      unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
+      auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol);
       // If `var_n` has an upperbound that is less than the divisor, mod can be
       // eliminated altogether.
       if (ub && *ub < divisor)
@@ -499,7 +512,8 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
 
         // Detect a variable as modulo of another variable w.r.t a
         // constant.
-        if (detectAsMod(*this, pos, *lbConst, *ubConst, memo, context)) {
+        if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context,
+                        memo)) {
           changed = true;
           continue;
         }
index 2d4a27c..3fc31ad 100644 (file)
@@ -190,3 +190,39 @@ func.func @fusion_for_multiple_blocks() {
   // PRODUCER-CONSUMER-NEXT: }
   return
 }
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: @fuse_higher_dim_nest_into_lower_dim_nest
+func.func @fuse_higher_dim_nest_into_lower_dim_nest() {
+  %A = memref.alloc() : memref<8x12x128x64xf32>
+  %B = memref.alloc() : memref<8x128x12x64xf32>
+  affine.for %arg205 = 0 to 8 {
+    affine.for %arg206 = 0 to 128 {
+      affine.for %arg207 = 0 to 12 {
+        affine.for %arg208 = 0 to 64 {
+          %a = affine.load %A[%arg205, %arg207, %arg206, %arg208] : memref<8x12x128x64xf32>
+          affine.store %a, %B[%arg205, %arg206, %arg207, %arg208] : memref<8x128x12x64xf32>
+        }
+      }
+    }
+  }
+  %C = memref.alloc() : memref<8x128x768xf16>
+  affine.for %arg205 = 0 to 8 {
+    affine.for %arg206 = 0 to 128 {
+      affine.for %arg207 = 0 to 768 {
+        %b = affine.load %B[%arg205, %arg206, %arg207 floordiv 64, %arg207 mod 64] : memref<8x128x12x64xf32>
+        %c = arith.truncf %b : f32 to f16
+        affine.store %c, %C[%arg205, %arg206, %arg207] : memref<8x128x768xf16>
+      }
+    }
+  }
+
+  // Check that fusion happens into the innermost loop of the consumer.
+  // PRODUCER-CONSUMER:      affine.for
+  // PRODUCER-CONSUMER-NEXT:   affine.for %{{.*}} = 0 to 128
+  // PRODUCER-CONSUMER-NEXT:     affine.for %{{.*}} = 0 to 768
+  // PRODUCER-CONSUMER-NOT:  affine.for
+  // PRODUCER-CONSUMER:      return
+  return
+}