#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
+#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallBitVector.h"
return false;
}
+/// Returns the constant lower bound of the affine.for loop defining `iv`, or
+/// std::nullopt if `iv` is not an affine.for induction variable or its lower
+/// bound is not a known constant.
+static std::optional<int64_t> getLowerBound(Value iv) {
+  AffineForOp ivOwner = getForInductionVarOwner(iv);
+  if (!ivOwner || !ivOwner.hasConstantLowerBound())
+    return std::nullopt;
+  return ivOwner.getConstantLowerBound();
+}
+
+/// Gets the constant upper bound (inclusive) on an affine.for `iv`: the
+/// largest value the IV actually attains. Returns std::nullopt if `iv` is not
+/// an affine.for induction variable or its upper bound is not constant.
+static std::optional<int64_t> getUpperBound(Value iv) {
+  AffineForOp forOp = getForInductionVarOwner(iv);
+  if (!forOp || !forOp.hasConstantUpperBound())
+    return std::nullopt;
+
+  // If its lower bound is also known, we can get a more precise bound
+  // whenever the step is not one: the IV only takes the values lb, lb + step,
+  // ..., so the last value reached below the exclusive upper bound is
+  // returned.
+  if (forOp.hasConstantLowerBound()) {
+    return forOp.getConstantUpperBound() - 1 -
+           (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
+               forOp.getStep();
+  }
+  return forOp.getConstantUpperBound() - 1;
+}
+
+/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
+/// the constant lower and upper bounds for its inputs provided in
+/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
+/// bound can't be computed. This method only handles simple sum of product
+/// expressions (w.r.t constant coefficients) so as to not depend on anything
+/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
+/// ... + c_n are handled, as is a top-level floordiv, ceildiv or mod by a
+/// positive constant. Other expressions involving floordiv, ceildiv, mod or
+/// semi-affine ones will lead to std::nullopt being returned.
+static std::optional<int64_t>
+getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                ArrayRef<std::optional<int64_t>> constLowerBounds,
+                ArrayRef<std::optional<int64_t>> constUpperBounds,
+                bool isUpper) {
+  // Handle divs and mods.
+  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
+    // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
+    // can compute an upper bound. floordiv/ceildiv by a positive constant are
+    // monotonically non-decreasing in the LHS, so dividing the LHS bound
+    // yields a valid bound of the same kind.
+    if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!rhsConst || rhsConst.getValue() < 1)
+        return std::nullopt;
+      auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                   constLowerBounds, constUpperBounds, isUpper);
+      if (!bound)
+        return std::nullopt;
+      return mlir::floorDiv(*bound, rhsConst.getValue());
+    }
+    if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        auto bound =
+            getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                            constLowerBounds, constUpperBounds, isUpper);
+        if (!bound)
+          return std::nullopt;
+        return mlir::ceilDiv(*bound, rhsConst.getValue());
+      }
+      return std::nullopt;
+    }
+    if (binOpExpr.getKind() == AffineExprKind::Mod) {
+      // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
+      // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
+      // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
+      auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (rhsConst && rhsConst.getValue() >= 1) {
+        int64_t rhsConstVal = rhsConst.getValue();
+        auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds,
+                                  /*isUpper=*/false);
+        auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
+                                  constLowerBounds, constUpperBounds, isUpper);
+        if (ub && lb &&
+            floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
+          return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
+        // Fall back to the universally valid bounds for a mod expression.
+        return isUpper ? rhsConstVal - 1 : 0;
+      }
+    }
+  }
+  // Flatten the expression.
+  SimpleAffineExprFlattener flattener(numDims, numSymbols);
+  flattener.walkPostOrder(expr);
+  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+  // TODO: Handle local variables. We can get hold of flattener.localExprs and
+  // get bound on the local expr recursively.
+  if (flattener.numLocals > 0)
+    return std::nullopt;
+  int64_t bound = 0;
+  // Substitute the constant lower or upper bound for the dimensional or
+  // symbolic input depending on `isUpper` to determine the bound. A positive
+  // coefficient takes the same-kind bound; a negative coefficient takes the
+  // opposite-kind bound.
+  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
+    if (flattenedExpr[i] > 0) {
+      auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    } else if (flattenedExpr[i] < 0) {
+      auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
+      if (!constBound)
+        return std::nullopt;
+      bound += *constBound * flattenedExpr[i];
+    }
+  }
+  // Constant term.
+  bound += flattenedExpr.back();
+  return bound;
+}
+
+/// Determine a constant upper bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
+/// is guaranteed to be less than or equal to it.
+static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
+                                            unsigned numSymbols,
+                                            ArrayRef<Value> operands) {
+  // A constant expression is its own (tight) bound; no need to inspect
+  // operands.
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+    return constExpr.getValue();
+
+  // Get the constant lower or upper bounds on the operands.
+  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                         constUpperBounds,
+                         /*isUpper=*/true);
+}
+
+/// Determine a constant lower bound for `expr` if one exists while exploiting
+/// values in `operands`. Note that the lower bound is an inclusive one. `expr`
+/// is guaranteed to be greater than or equal to it.
+static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
+                                            unsigned numSymbols,
+                                            ArrayRef<Value> operands) {
+  // A constant expression is its own (tight) bound; no need to inspect
+  // operands.
+  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
+    return constExpr.getValue();
+
+  // Get the constant lower or upper bounds on the operands.
+  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
+  constLowerBounds.reserve(operands.size());
+  constUpperBounds.reserve(operands.size());
+  for (Value operand : operands) {
+    constLowerBounds.push_back(getLowerBound(operand));
+    constUpperBounds.push_back(getUpperBound(operand));
+  }
+
+  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
+                         constUpperBounds,
+                         /*isUpper=*/false);
+}
+
/// Simplify `expr` while exploiting information from the values in `operands`.
-static void simplifyExprAndOperands(AffineExpr &expr,
+static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
+                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
// We do this only for certain floordiv/mod expressions.
auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
// Simplify the child expressions first.
AffineExpr lhs = binExpr.getLHS();
AffineExpr rhs = binExpr.getRHS();
-  simplifyExprAndOperands(lhs, operands);
-  simplifyExprAndOperands(rhs, operands);
+  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
+  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
-  if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
-                   binExpr.getKind() != AffineExprKind::Mod)) {
+  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
+                   expr.getKind() != AffineExprKind::CeilDiv &&
+                   expr.getKind() != AffineExprKind::Mod)) {
return;
}
int64_t rhsConstVal = rhsConst.getValue();
// Undefined exprsessions aren't touched; IR can still be valid with them.
-  if (rhsConstVal == 0)
+  if (rhsConstVal <= 0)
return;
-  AffineExpr quotientTimesDiv, rem;
-  int64_t divisor;
+  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
+  MLIRContext *context = expr.getContext();
+  std::optional<int64_t> lhsLbConst =
+      getLowerBound(lhs, numDims, numSymbols, operands);
+  std::optional<int64_t> lhsUbConst =
+      getUpperBound(lhs, numDims, numSymbols, operands);
+  if (lhsLbConst && lhsUbConst) {
+    int64_t lhsLbConstVal = *lhsLbConst;
+    int64_t lhsUbConstVal = *lhsUbConst;
+    // lhs floordiv c folds to a single constant if lhs is bounded within a
+    // range whose elements all have the same quotient w.r.t. c.
+    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
+        floorDiv(lhsLbConstVal, rhsConstVal) ==
+            floorDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs ceildiv c is a single value if the entire range has the same ceil
+    // quotient.
+    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
+        ceilDiv(lhsLbConstVal, rhsConstVal) ==
+            ceilDiv(lhsUbConstVal, rhsConstVal)) {
+      expr =
+          getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
+      return;
+    }
+    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
+    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
+        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
+      expr = lhs;
+      return;
+    }
+  }
// Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
// mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
// `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
// And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
+  AffineExpr quotientTimesDiv, rem;
+  int64_t divisor;
if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
if (rhsConstVal % divisor == 0 &&
binExpr.getKind() == AffineExprKind::FloorDiv) {
SmallVector<AffineExpr> newResults;
newResults.reserve(map.getNumResults());
for (AffineExpr expr : map.getResults()) {
- simplifyExprAndOperands(expr, operands);
+ simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
+ operands);
newResults.push_back(expr);
}
map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
"test.foo"(%x) : (f32) -> ()
// %i is aligned at 32 boundary and %ii < 32.
- // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
- %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x32xf32>
+ // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 16]
+ %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 16] : memref<?x32xf32>
"test.foo"(%a) : (f32) -> ()
// CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
%b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x32xf32>
return
}
+// CHECK-LABEL: func @simplify_div_mod_with_operands
+func.func @simplify_div_mod_with_operands(%N: index, %A: memref<64xf32>, %unknown: index) {
+  // Each loop below has constant bounds, so the simplifier can bound its IV
+  // and fold div/mod expressions whose value is constant over that range.
+  // CHECK: affine.for %[[I:.*]] = 0 to 32
+  %cst = arith.constant 1.0 : f32
+  affine.for %i = 0 to 32 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1]
+    affine.store %cst, %A[(%i + 1) ceildiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[I]]]
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[2 * %i floordiv 64] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0]
+    affine.store %cst, %A[(%i mod 16) floordiv 16] : memref<64xf32>
+
+    // The ones below can't be simplified.
+    affine.store %cst, %A[%i floordiv 16] : memref<64xf32>
+    affine.store %cst, %A[%i mod 16] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 15] : memref<64xf32>
+    affine.store %cst, %A[%i mod 31] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 16] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[(%{{.*}} mod 16) floordiv 15] : memref<64xf32>
+    // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 31] : memref<64xf32>
+  }
+
+  // A negative lower bound: quotients differ across the IV range.
+  affine.for %i = -8 to 32 {
+    // Can't be simplified.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} floordiv 32] : memref<64xf32>
+    affine.store %cst, %A[%i floordiv 32] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} mod 32] : memref<64xf32>
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // floordiv rounds toward -inf; (%i - 96) floordiv 64 will be -2.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[2 + (%i - 96) floordiv 64] : memref<64xf32>
+  }
+
+  // CHECK: affine.for %[[II:.*]] = 8 to 16
+  affine.for %i = 8 to 16 {
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32>
+    affine.store %cst, %A[%i floordiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[2] : memref<64xf32>
+    affine.store %cst, %A[(%i + 1) ceildiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]] mod 8] : memref<64xf32>
+    affine.store %cst, %A[%i mod 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[%[[II]]] : memref<64xf32>
+    affine.store %cst, %A[%i mod 32] : memref<64xf32>
+    // Upper bound on the mod 32 expression will be 15.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 32) floordiv 16] : memref<64xf32>
+    // Lower bound on the mod 16 expression will be 8.
+    // CHECK: affine.store %{{.*}}, %{{.*}}[1] : memref<64xf32>
+    affine.store %cst, %A[(%i mod 16) floordiv 8] : memref<64xf32>
+    // CHECK: affine.store %{{.*}}, %{{.*}}[0] : memref<64xf32>
+    affine.store %cst, %A[(%unknown mod 16) floordiv 16] : memref<64xf32>
+  }
+  return
+}
+
// -----
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * ((-s0 + 40961) ceildiv 512))>