[MLIR] Fold away divs and mods in affine ops with operand info
authorUday Bondhugula <uday@polymagelabs.com>
Fri, 10 Feb 2023 08:05:04 +0000 (13:35 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Fri, 10 Feb 2023 08:09:56 +0000 (13:39 +0530)
Fold away divs and mods in affine maps exploiting operand info during
canonicalization. This simplifies affine map applications such as the ones
below:

```
// Simple ones.
affine.for %i = 0 to 32 {
  affine.load %A[%i floordiv 32]
  affine.load %A[%i mod 32]
  affine.load %A[2 * %i floordiv 64]
  affine.load %A[(%i mod 16) floordiv 16]
  ...
}

// Others.
 affine.for %i = -8 to 32 {
   // Will be simplified %A[0].
   affine.store %cst, %A[2 + (%i - 96) floordiv 64] : memref<64xf32>
}
```

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D143456

mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir

index 4481c14..284e099 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -673,8 +674,168 @@ static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
   return false;
 }
 
+/// Gets the constant lower bound on an `iv`.
+static std::optional<int64_t> getLowerBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (forOp && forOp.hasConstantLowerBound())
+    return forOp.getConstantLowerBound();
+  return std::nullopt;
+}
+
+/// Gets the constant upper bound on an affine.for `iv`.
+static Optional<int64_t> getUpperBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (!forOp || !forOp.hasConstantUpperBound())
+    return std::nullopt;
+
+  // If its lower bound is also known, we can get a more precise bound
+  // whenever the step is not one.
+  if (forOp.hasConstantLowerBound()) {
+    return forOp.getConstantUpperBound() - 1 -
+           (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
+               forOp.getStep();
+  }
+  return forOp.getConstantUpperBound() - 1;
+}
+
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return None if such a bound can't
+/// be computed. This method only handles simple sum of product expressions
+/// (w.r.t constant coefficients) so as to not depend on anything heavyweight in
+/// `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 + ... + c_n are
+/// handled. Expressions involving floordiv, ceildiv, mod or semi-affine ones
+/// will lead a none being returned.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                ArrayRef<Optional<int64_t>> constLowerBounds,
+                ArrayRef<Optional<int64_t>> constUpperBounds, bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+    // can compute an upper bound.
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                   constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                            constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds,
+                                  /*isUpper=*/false);
+        auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}
+
+/// Determine a constant upper bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
+/// is guaranteed to be less than or equal to it.
+static Optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
+                                       unsigned numSymbols,
+                                       ArrayRef<Value> operands) {
+  // Get the constant lower or upper bounds on the operands.
+  SmallVector<Optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+    return constExpr.getValue();
+
+  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                         constUpperBounds,
+                         /*isUpper=*/true);
+}
+
+/// Determine a constant lower bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
+/// is guaranteed to be less than or equal to it.
+static Optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
+                                       unsigned numSymbols,
+                                       ArrayRef<Value> operands) {
+  // Get the constant lower or upper bounds on the operands.
+  SmallVector<Optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  Optional<int64_t> lowerBound;
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+    lowerBound = constExpr.getValue();
+  } else {
+    lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                                 constUpperBounds,
+                                 /*isUpper=*/false);
+  }
+  return lowerBound;
+}
+
 /// Simplify `expr` while exploiting information from the values in `operands`.
-static void simplifyExprAndOperands(AffineExpr &expr,
+static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
+                                    unsigned numSymbols,
                                     ArrayRef<Value> operands) {
   // We do this only for certain floordiv/mod expressions.
   auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
@@ -684,13 +845,14 @@ static void simplifyExprAndOperands(AffineExpr &expr,
   // Simplify the child expressions first.
   AffineExpr lhs = binExpr.getLHS();
   AffineExpr rhs = binExpr.getRHS();
-  simplifyExprAndOperands(lhs, operands);
-  simplifyExprAndOperands(rhs, operands);
+  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
+  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
   expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
 
   binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
-  if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
-                   binExpr.getKind() != AffineExprKind::Mod)) {
+  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
+                   expr.getKind() != AffineExprKind::CeilDiv &&
+                   expr.getKind() != AffineExprKind::Mod)) {
     return;
   }
 
@@ -703,16 +865,50 @@ static void simplifyExprAndOperands(AffineExpr &expr,
 
   int64_t rhsConstVal = rhsConst.getValue();
   // Undefined exprsessions aren't touched; IR can still be valid with them.
-  if (rhsConstVal == 0)
+  if (rhsConstVal <= 0)
     return;
 
-  AffineExpr quotientTimesDiv, rem;
-  int64_t divisor;
+  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
+  MLIRContext *context = expr.getContext();
+  std::optional<int64_t> lhsLbConst =
+      getLowerBound(lhs, numDims, numSymbols, operands);
+  std::optional<int64_t> lhsUbConst =
+      getUpperBound(lhs, numDims, numSymbols, operands);
+  if (lhsLbConst && lhsUbConst) {
+    int64_t lhsLbConstVal = *lhsLbConst;
+    int64_t lhsUbConstVal = *lhsUbConst;
+    // lhs floordiv c is a single value lhs is bounded in a range `c` that has
+    // the same quotient.
+    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
+        floorDiv(lhsLbConstVal, rhsConstVal) ==
+            floorDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs ceildiv c is a single value if the entire range has the same ceil
+    // quotient.
+    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
+        ceilDiv(lhsLbConstVal, rhsConstVal) ==
+            ceilDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
+    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
+        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
+      expr = lhs;
+      return;
+    }
+  }
 
   // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
   // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
   // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
   // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
+  AffineExpr quotientTimesDiv, rem;
+  int64_t divisor;
   if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
     if (rhsConstVal % divisor == 0 &&
         binExpr.getKind() == AffineExprKind::FloorDiv) {
@@ -745,7 +941,8 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
   SmallVector<AffineExpr> newResults;
   newResults.reserve(map.getNumResults());
   for (AffineExpr expr : map.getResults()) {
-    simplifyExprAndOperands(expr, operands);
+    simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
+                            operands);
     newResults.push_back(expr);
   }
   map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
index ddebab4..d6ca503 100644 (file)
@@ -1170,8 +1170,8 @@ func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
       "test.foo"(%x) : (f32) -> ()
 
       // %i is aligned at 32 boundary and %ii < 32.
-      // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
-      %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x32xf32>
+      // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 16]
+      %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 16] : memref<?x32xf32>
       "test.foo"(%a) : (f32) -> ()
       // CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
       %b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x32xf32>
@@ -1202,6 +1202,66 @@ func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
   return
 }
 
+// CHECK-LABEL: func @simplify_div_mod_with_operands
+func.func @simplify_div_mod_with_operands(%N: index, %A: memref<64xf32>, %unknown: index) {
+  // CHECK: affine.for %[[I:.*]] = 0 to 32
+  %cst = arith.constant 1.0 : f32
+  affine.for %i = 0 to 32 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1]
+    affine.store %cst, %A[(%i + 1) ceildiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[I]]]
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[2 * %i floordiv 64] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[(%i mod 16) floordiv 16] : memref<64xf32>
+
+    // The ones below can't be simplified.
+    affine.store %cst, %A[%i floordiv 16] : memref<64xf32>
+    affine.store %cst, %A[%i mod 16] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 15] : memref<64xf32>
+    affine.store %cst, %A[%i mod 31] : memref<64xf32>
+    // CHECK:      affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[(%{{.*}} mod 16) floordiv 15] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 31] : memref<64xf32>
+  }
+
+  affine.for %i = -8 to 32 {
+    // Can't be simplified.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 32] : memref<64xf32>
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 32] : memref<64xf32>
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // floordiv rounds toward -inf; (%i - 96) floordiv 64 will be -2.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[2 + (%i - 96) floordiv 64] : memref<64xf32>
+  }
+
+  // CHECK: affine.for %[[II:.*]] = 8 to 16
+  affine.for %i = 8 to 16 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32>
+    affine.store %cst, %A[%i floordiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[2] : memref<64xf32>
+    affine.store %cst, %A[(%i + 1) ceildiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]] mod 8] : memref<64xf32>
+    affine.store %cst, %A[%i mod 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]]] : memref<64xf32>
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // Upper bound on the mod 32 expression will be 15.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 32) floordiv 16] : memref<64xf32>
+    // Lower bound on the mod 16 expression will be 8.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[(%unknown mod 16) floordiv 16] : memref<64xf32>
+  }
+  return
+}
+
 // -----
 
 //           CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>