//
// `var_q = var_n floordiv divisor`.
//
+// The first `num` dimensional variables starting at `offset` are
+// derived/to-be-derived in terms of the remaining variables. The remaining
+// variables are assigned trivial affine expressions in `memo`. For example,
+// `memo` is initialized as follows for a `cst` with 5 dims when offset=2, num=2:
+// memo ==> d0 d1 . . d2 ...
+// cst  ==> c0 c1 c2 c3 c4 ...
+//
// Returns true if the above mod or floordiv is detected, updating 'memo' with
// these new expressions. Returns false otherwise.
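+//
+// For instance (an illustration of the contract above): if `cst` implies
+// `var_n = 64 * var_q + var_r` with `0 <= var_r <= 63`, the variable at `pos`
+// (here `var_r`) is expressed in `memo` as `var_n mod 64`, and `var_q` as
+// `var_n floordiv 64`.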
static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
- int64_t lbConst, int64_t ubConst,
- SmallVectorImpl<AffineExpr> &memo,
- MLIRContext *context) {
+ unsigned offset, unsigned num, int64_t lbConst,
+ int64_t ubConst, MLIRContext *context,
+ SmallVectorImpl<AffineExpr> &memo) {
assert(pos < cst.getNumVars() && "invalid position");
// Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
// be found.
// Express `var_r` as `var_n % divisor` and store the expression in `memo`.
if (quotientCount >= 1) {
- auto ub = cst.getConstantBound64(BoundType::UB, dimExpr.getPosition());
+ // Find the column corresponding to `dimExpr`. `num` columns starting at
+ // `offset` correspond to previously unknown variables. The column
+ // corresponding to the trivially known `dimExpr` can be on either side
+ // of these.
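+// E.g., with offset = 2 and num = 2 as in the example in the function
+// comment, `d1` maps to column 1 while `d2` maps to column 4.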
+ unsigned dimExprPos = dimExpr.getPosition();
+ unsigned dimExprCol = dimExprPos < offset ? dimExprPos : dimExprPos + num;
+ auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol);
// If `var_n` has an upper bound that is less than the divisor, mod can be
// eliminated altogether.
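+// E.g., if `0 <= var_n <= 50` and the divisor is 64, `var_n mod 64` is simply
+// `var_n`.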
if (ub && *ub < divisor)
// Detect a variable as modulo of another variable w.r.t a
// constant.
- if (detectAsMod(*this, pos, *lbConst, *ubConst, memo, context)) {
+ if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context,
+ memo)) {
changed = true;
continue;
}
// PRODUCER-CONSUMER-NEXT: }
return
}
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: @fuse_higher_dim_nest_into_lower_dim_nest
+func.func @fuse_higher_dim_nest_into_lower_dim_nest() {
+ %A = memref.alloc() : memref<8x12x128x64xf32>
+ %B = memref.alloc() : memref<8x128x12x64xf32>
+ affine.for %arg205 = 0 to 8 {
+ affine.for %arg206 = 0 to 128 {
+ affine.for %arg207 = 0 to 12 {
+ affine.for %arg208 = 0 to 64 {
+ %a = affine.load %A[%arg205, %arg207, %arg206, %arg208] : memref<8x12x128x64xf32>
+ affine.store %a, %B[%arg205, %arg206, %arg207, %arg208] : memref<8x128x12x64xf32>
+ }
+ }
+ }
+ }
+ %C = memref.alloc() : memref<8x128x768xf16>
+ affine.for %arg205 = 0 to 8 {
+ affine.for %arg206 = 0 to 128 {
+ affine.for %arg207 = 0 to 768 {
+ %b = affine.load %B[%arg205, %arg206, %arg207 floordiv 64, %arg207 mod 64] : memref<8x128x12x64xf32>
+ %c = arith.truncf %b : f32 to f16
+ affine.store %c, %C[%arg205, %arg206, %arg207] : memref<8x128x768xf16>
+ }
+ }
+ }
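+
+  // Fusing the 4-d producer nest into this 3-d consumer nest requires
+  // expressing the producer's %arg207 and %arg208 in terms of the consumer's
+  // %arg207, i.e., as %arg207 floordiv 64 and %arg207 mod 64, which exercises
+  // the offset-aware mod/floordiv detection (detectAsMod).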
+
+ // Check that fusion happens into the innermost loop of the consumer.
+ // PRODUCER-CONSUMER: affine.for
+ // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 128
+ // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 768
+ // PRODUCER-CONSUMER-NOT: affine.for
+ // PRODUCER-CONSUMER: return
+ return
+}
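+
+// A sketch of the fused nest the PRODUCER-CONSUMER checks above are meant to
+// pin down (illustrative only; the pass typically routes the intermediate
+// values through a private buffer whose exact shape and access maps may
+// differ):
+//   affine.for %i = 0 to 8 {
+//     affine.for %j = 0 to 128 {
+//       affine.for %k = 0 to 768 {
+//         // producer slice: load from %A[%i, %k floordiv 64, %j, %k mod 64]
+//         // consumer: arith.truncf and store to %C[%i, %j, %k]
+//       }
+//     }
+//   }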