#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include <numeric>
using namespace mlir;
return result[0];
}
+/// Computes the largest integer known to divide `e`, refining the
+/// operand-agnostic answer with what can be learnt from the values in
+/// `operands`.
+static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
+  // Operand information is only exploited when `e` is a dim expression. For
+  // anything else, the structural query (which isn't aware of `operands`) is
+  // the best we have.
+  // TODO: More powerful simplification would have to modify
+  // getLargestKnownDivisor to take `operands` and exploit that information as
+  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
+  // can't be part of the IR library but of the `Analysis` library. The IR
+  // library can only really depend on simple O(1) checks.
+  auto dim = e.dyn_cast<AffineDimExpr>();
+  if (!dim)
+    return e.getLargestKnownDivisor();
+
+  // Only loop IVs are inspected. mlir::getLargestKnownDivisorOfValue isn't
+  // needed here since the other desired simplifications are expected to be
+  // part of other canonicalizations; it is also part of the LoopAnalysis
+  // library.
+  // TODO: With the right accessors, this can be extended to
+  // LoopLikeOpInterface.
+  AffineForOp loop = getForInductionVarOwner(operands[dim.getPosition()]);
+  if (!loop)
+    return 1;
+
+  // An IV starting at constant zero only takes values that are multiples of
+  // the step.
+  if (loop.hasConstantLowerBound() && loop.getConstantLowerBound() == 0)
+    return loop.getStep();
+
+  // Otherwise the IV is `lb + i * step`, so whatever divides both the lower
+  // bound expressions and the step divides the IV.
+  uint64_t lbDivisor =
+      loop.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
+  return std::gcd(lbDivisor, loop.getStep());
+}
+
+/// Returns true when `e` is provably within the half-open range [0, `k`).
+/// Only two simple shapes of `e` are recognized: a constant, and a dim
+/// expression bound to an affine.for induction variable with constant bounds.
+static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
+                                   int64_t k) {
+  // Constants are checked directly against the range.
+  if (auto cst = e.dyn_cast<AffineConstantExpr>())
+    return 0 <= cst.getValue() && cst.getValue() < k;
+
+  if (auto dim = e.dyn_cast<AffineDimExpr>()) {
+    // TODO: With the right accessors, this can be extended to
+    // LoopLikeOpInterface.
+    AffineForOp loop = getForInductionVarOwner(operands[dim.getPosition()]);
+    // Operands defined by a constant or an affine.apply op aren't considered
+    // here: such cases will already be handled by other patterns, and
+    // propagation of loop IVs or constants would happen.
+    return loop && loop.hasConstantLowerBound() &&
+           loop.getConstantLowerBound() >= 0 &&
+           loop.hasConstantUpperBound() && loop.getConstantUpperBound() <= k;
+  }
+
+  // Any other kind of expression is not analyzed.
+  return false;
+}
+
+/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
+/// The two addends may appear in either order. On success:
+///   - `div` is set to `d`, the largest known divisor of the multiple-of-d
+///     addend;
+///   - `quotientTimesDiv` is set to that whole addend (i.e., d*e_1, not e_1);
+///   - `rem` is set to e_2.
+/// Returns false if `e` isn't an add or neither ordering matches. Note that
+/// `div` may be overwritten even when false is returned; `quotientTimesDiv`
+/// and `rem` are only written on success.
+static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
+                           AffineExpr &quotientTimesDiv, AffineExpr &rem) {
+  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
+  if (!bin || bin.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr llhs = bin.getLHS();
+  AffineExpr rlhs = bin.getRHS();
+
+  // First try: LHS is the multiple of `div`, RHS is the bounded remainder.
+  div = getLargestKnownDivisor(llhs, operands);
+  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
+    quotientTimesDiv = llhs;
+    rem = rlhs;
+    return true;
+  }
+
+  // Second try with the roles swapped, since add is commutative.
+  div = getLargestKnownDivisor(rlhs, operands);
+  if (isNonNegativeBoundedBy(llhs, operands, div)) {
+    quotientTimesDiv = rlhs;
+    rem = llhs;
+    return true;
+  }
+  return false;
+}
+
+/// Simplify `expr` in place while exploiting information from the values in
+/// `operands`. Child expressions of any binary expression are simplified
+/// recursively; on top of that, floordiv/mod expressions with a positive
+/// constant RHS are rewritten when the LHS is known to be suitably bounded or
+/// aligned. Expressions that aren't binary are left untouched.
+static void simplifyExprAndOperands(AffineExpr &expr,
+                                    ArrayRef<Value> operands) {
+  // We do this only for certain floordiv/mod expressions.
+  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr)
+    return;
+
+  // Simplify the child expressions first.
+  auto lhs = binExpr.getLHS();
+  auto rhs = binExpr.getRHS();
+  simplifyExprAndOperands(lhs, operands);
+  simplifyExprAndOperands(rhs, operands);
+  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
+
+  binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr || (binExpr.getKind() != AffineExprKind::FloorDiv &&
+                   binExpr.getKind() != AffineExprKind::Mod)) {
+    return;
+  }
+
+  // The rebuilt expression need not have the operands it was built from, so
+  // read them back from `binExpr` rather than reusing the pre-construction
+  // values.
+  lhs = binExpr.getLHS();
+  rhs = binExpr.getRHS();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+  if (!rhsConst)
+    return;
+
+  int64_t rhsConstVal = rhsConst.getValue();
+  // Leave floordiv/mod by a non-positive constant alone: the rewrites below
+  // assume a positive divisor, and a zero constant would make the `%` checks
+  // divide by zero.
+  if (rhsConstVal <= 0)
+    return;
+
+  AffineExpr quotientTimesDiv, rem;
+  int64_t divisor;
+
+  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
+  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
+  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
+  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
+  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
+    if (rhsConstVal % divisor == 0 &&
+        binExpr.getKind() == AffineExprKind::FloorDiv) {
+      expr = quotientTimesDiv.floorDiv(rhsConst);
+    } else if (divisor % rhsConstVal == 0 &&
+               binExpr.getKind() == AffineExprKind::Mod) {
+      expr = rem % rhsConst;
+    }
+    // Once the LHS matched the q*d + r form, no other rewrite is attempted.
+    return;
+  }
+
+  // Handle the simple case when the LHS expression can be either upper
+  // bounded or is a known multiple of RHS constant.
+  // lhs floordiv c -> 0 if 0 <= lhs < c,
+  // lhs mod c -> 0 if lhs % c = 0.
+  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
+       binExpr.getKind() == AffineExprKind::FloorDiv) ||
+      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
+       binExpr.getKind() == AffineExprKind::Mod)) {
+    expr = getAffineConstantExpr(0, expr.getContext());
+  }
+}
+
+/// Rewrites `map` so that each of its result expressions is simplified using
+/// what is known about the values in `operands`. The number of dims and
+/// symbols is left unchanged.
+// Use "unused attribute" marker to silence warning stemming from the inability
+// to see through the template expansion.
+static void LLVM_ATTRIBUTE_UNUSED
+simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
+  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
+  // Take a mutable copy of the results and simplify each one in place.
+  ArrayRef<AffineExpr> oldResults = map.getResults();
+  SmallVector<AffineExpr> simplified(oldResults.begin(), oldResults.end());
+  for (AffineExpr &result : simplified)
+    simplifyExprAndOperands(result, operands);
+  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), simplified,
+                       map.getContext());
+}
+
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
SmallVector<Value, 8> resultOperands(oldOperands);
composeAffineMapAndOperands(&map, &resultOperands);
canonicalizeMapAndOperands(&map, &resultOperands);
+ simplifyMapWithOperands(map, resultOperands);
if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
resultOperands.begin()))
return failure();
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
- affine.for %i0 = 0 to 3 {
+ affine.for %i0 = 0 to 16 {
%x0 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)> (%i0)[%c4]
- affine.for %i1 = 0 to 3 {
+ affine.for %i1 = 0 to 16 {
%x1 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)> (%i1)[%c8]
- affine.for %i2 = 0 to 3 {
+ affine.for %i2 = 0 to 16 {
%x2 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)> (%i2)[%c4]
- affine.for %i3 = 0 to 3 {
+ affine.for %i3 = 0 to 16 {
%x3 = affine.apply affine_map<(d0)[s0] -> (d0 mod s0)> (%i3)[%c8]
%x40 = affine.apply affine_map<(d0, d1, d2, d3)[s0, s1] ->
return %s: memref<32x64xf32>
}
}
+
+// -----
+
+// Simplification of maps exploiting operand info.
+
+// CHECK-LABEL: func @simplify_with_operands
+func.func @simplify_with_operands(%N: index, %A: memref<?x32xf32>) {
+ // CHECK-NEXT: affine.for %[[I:.*]] = 0 to %{{.*}}
+ affine.for %i = 0 to %N step 32 {
+ // CHECK-NEXT: affine.for %[[II:.*]] = 0 to 32
+ affine.for %ii = 0 to 32 {
+ // %ii is less than 32 and %i is a multiple of 32 (it starts at 0 with
+ // step 32), so both subscripts fold to 0.
+ // CHECK: affine.load %{{.*}}[0, 0]
+ %x = affine.load %A[%ii floordiv 32, %i mod 32] : memref<?x32xf32>
+ "test.foo"(%x) : (f32) -> ()
+
+ // %i is aligned at 32 boundary and %ii < 32.
+ // CHECK: affine.load %{{.*}}[%[[I]] floordiv 32, %[[II]] mod 32]
+ %a = affine.load %A[(%i + %ii) floordiv 32, (%i + %ii) mod 32] : memref<?x32xf32>
+ "test.foo"(%a) : (f32) -> ()
+ // 64 is a multiple of 32, so only the floordiv drops %ii; the mod is
+ // expected to stay as is.
+ // CHECK: affine.load %{{.*}}[%[[I]] floordiv 64, (%[[I]] + %[[II]]) mod 64]
+ %b = affine.load %A[(%i + %ii) floordiv 64, (%i + %ii) mod 64] : memref<?x32xf32>
+ "test.foo"(%b) : (f32) -> ()
+ // 16 divides 32, so only the mod drops %i; the floordiv is expected to
+ // stay as is.
+ // CHECK: affine.load %{{.*}}[(%[[I]] + %[[II]]) floordiv 16, %[[II]] mod 16]
+ %c = affine.load %A[(%i + %ii) floordiv 16, (%i + %ii) mod 16] : memref<?x32xf32>
+ "test.foo"(%c) : (f32) -> ()
+ }
+ }
+
+ // Should not simplify: the lower bound is negative, so %i isn't known to be
+ // non-negative.
+ affine.for %i = -1 to 32 {
+ // CHECK: affine.load %{{.*}}[%{{.*}} floordiv {{.*}}, %{{.*}} mod {{.*}}] :
+ %x = affine.load %A[%i floordiv 32, %i mod 32] : memref<?x32xf32>
+ "test.foo"(%x) : (f32) -> ()
+ }
+
+ return
+}