[mlir][affine][NFC] Extract core functionality of `canonicalizeMinMaxOp`
authorMatthias Springer <springerm@google.com>
Wed, 4 Jan 2023 09:56:43 +0000 (10:56 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 4 Jan 2023 10:25:44 +0000 (11:25 +0100)
Move code from SCF to Affine: Add a new helper function `simplifyConstrainedMinMaxOp` to Affine/Analysis/Utils.h. `canonicalizeMinMaxOp` was originally designed for loop peeling, but it is not SCF-specific and can be used to simplify any affine.min/max ops.

Various functions in SCF/Transforms are simplified by dropping unnecessary parameters.

Differential Revision: https://reviews.llvm.org/D140962

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
mlir/lib/Dialect/Affine/Analysis/CMakeLists.txt
mlir/lib/Dialect/Affine/Analysis/Utils.cpp
mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index 3ce4ea3..d4ea3ea 100644 (file)
@@ -27,6 +27,7 @@
 namespace mlir {
 
 class AffineForOp;
+class AffineValueMap;
 class Block;
 class Location;
 struct MemRefAccess;
@@ -384,6 +385,13 @@ unsigned getInnermostCommonLoopDepth(
     ArrayRef<Operation *> ops,
     SmallVectorImpl<AffineForOp> *surroundingLoops = nullptr);
 
+/// Try to simplify the given affine.min or affine.max op to an affine map with
+/// a single result and operands, taking into account the specified constraint
+/// set. Return failure if no simplified version could be found.
+FailureOr<AffineValueMap>
+simplifyConstrainedMinMaxOp(Operation *op,
+                            FlatAffineValueConstraints constraints);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_ANALYSIS_UTILS_H
index 4c7dc28..fea8704 100644 (file)
@@ -39,70 +39,23 @@ class IfOp;
 using LoopMatcherFn = function_ref<LogicalResult(
     Value, OpFoldResult &, OpFoldResult &, OpFoldResult &)>;
 
-/// Try to canonicalize an min/max operations in the context of for `loops` with
-/// a known range.
+/// Try to canonicalize the given affine.min/max operation in the context of
+/// for `loops` with a known range.
 ///
-/// `map` is the body of the min/max operation and `operands` are the SSA values
-/// that the dimensions and symbols are bound to; dimensions are listed first.
-/// If `isMin`, the operation is a min operation; otherwise, a max operation.
 /// `loopMatcher` is used to retrieve loop bounds and the step size for a given
 /// iteration variable.
 ///
 /// Note: `loopMatcher` allows this function to be used with any "for loop"-like
 /// operation (scf.for, scf.parallel and even ops defined in other dialects).
 LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op,
-                                         AffineMap map, ValueRange operands,
-                                         bool isMin, LoopMatcherFn loopMatcher);
+                                         LoopMatcherFn loopMatcher);
 
-/// Attempt to canonicalize min/max operations by proving that their value is
-/// bounded by the same lower and upper bound. In such cases, the operation can
-/// be folded away.
-///
-/// Bounds are computed by FlatAffineValueConstraints. Invariants required for
-/// finding/proving bounds should be supplied via `constraints`.
-///
-/// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
-/// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
-///    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
-///    `op` but are not part of `constraints`, are added as extra symbols.
-/// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
-///    * If `isMin`: r_i >= opBound
-///    * If `isMax`: r_i <= opBound
-///    If this is the case, ub(op) == lb(op).
-/// 4. Replace `op` with `opBound`.
-///
-/// In summary, the following constraints are added throughout this function.
-/// Note: `invar` are dimensions added by the caller to express the invariants.
-/// (Showing only the case where `isMin`.)
-///
-///  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (various eq./ineq. constraining `invar`, added by the caller)
-///    ... |     0 |       0 |   0 |             0 |   ... |               ...
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (various ineq. constraining `op` in terms of `op` operands (`invar` and
-///    extra `op` operands "extra syms" that are not in `invar`)).
-///    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
-///    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (for each `op` map result r_i: set r_i to corresponding map result,
-///    prove that r_i >= minOpUb via contradiction)
-///    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
-///      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
-///
-FailureOr<AffineApplyOp>
-canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
-                     ValueRange operands, bool isMin,
-                     FlatAffineValueConstraints constraints);
-
-/// Try to simplify a min/max operation `op` after loop peeling. This function
-/// can simplify min/max operations such as (ub is the previous upper bound of
-/// the unpeeled loop):
+/// Try to simplify the given affine.min/max operation `op` after loop peeling.
+/// This function can simplify min/max operations such as (ub is the previous
+/// upper bound of the unpeeled loop):
 /// ```
 /// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
-/// %r = affine.min #affine.min #map(%iv)[%step, %ub]
+/// %r = affine.min #map(%iv)[%step, %ub]
 /// ```
 /// and rewrites them into (in the case the peeled loop):
 /// ```
@@ -111,8 +64,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
 /// min/max operations inside the partial iteration are rewritten in a similar
 /// way.
 LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
-                                    AffineMap map, ValueRange operands,
-                                    bool isMin, Value iv, Value ub, Value step,
+                                    Value iv, Value ub, Value step,
                                     bool insideLoop);
 
 } // namespace scf
index 38b64d5..61e49b0 100644 (file)
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineAnalysis
   MLIRAnalysis
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
+  MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   MLIRPresburger
index 47e26e0..537c718 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IntegerSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
@@ -1362,3 +1363,184 @@ IntegerSet mlir::simplifyIntegerSet(IntegerSet set) {
   assert(simplifiedSet && "guaranteed to succeed while roundtripping");
   return simplifiedSet;
 }
+
+static void unpackOptionalValues(ArrayRef<Optional<Value>> source,
+                                 SmallVector<Value> &target) {
+  target = llvm::to_vector<4>(llvm::map_range(source, [](Optional<Value> val) {
+    return val.has_value() ? *val : Value();
+  }));
+}
+
+/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
+/// constraints drawn from an affine map. Before adding the constraint, the
+/// dimensions/symbols of the affine map are aligned with `constraints`.
+/// `operands` are the SSA Value operands used with the affine map.
+/// Note: This function adds a new symbol column to the `constraints` for each
+/// dimension/symbol that exists in the affine map but not in `constraints`.
+static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
+                                      IntegerPolyhedron::BoundType type,
+                                      unsigned pos, AffineMap map,
+                                      ValueRange operands) {
+  SmallVector<Value> dims, syms, newSyms;
+  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
+  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
+
+  AffineMap alignedMap =
+      alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
+  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
+    constraints.appendSymbolVar(newSyms[i]);
+  return constraints.addBound(type, pos, alignedMap);
+}
+
+/// Add `val` to each result of `map`.
+static AffineMap addConstToResults(AffineMap map, int64_t val) {
+  SmallVector<AffineExpr> newResults;
+  for (AffineExpr r : map.getResults())
+    newResults.push_back(r + val);
+  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
+                        map.getContext());
+}
+
+// Attempt to simplify the given min/max operation by proving that its value is
+// bounded by the same lower and upper bound.
+//
+// Bounds are computed by FlatAffineValueConstraints. Invariants required for
+// finding/proving bounds should be supplied via `constraints`.
+//
+// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
+// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
+//    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
+//    `op` but are not part of `constraints`, are added as extra symbols.
+// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
+//    * If `isMin`: r_i >= opBound
+//    * If `isMax`: r_i <= opBound
+//    If this is the case, ub(op) == lb(op).
+// 4. Replace `op` with `opBound`.
+//
+// In summary, the following constraints are added throughout this function.
+// Note: `invar` are dimensions added by the caller to express the invariants.
+// (Showing only the case where `isMin`.)
+//
+//  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
+//  ------+-------+---------+-----+---------------+-------+-------------------
+//   (various eq./ineq. constraining `invar`, added by the caller)
+//    ... |     0 |       0 |   0 |             0 |   ... |               ...
+//  ------+-------+---------+-----+---------------+-------+-------------------
+//  (various ineq. constraining `op` in terms of `op` operands (`invar` and
+//    extra `op` operands "extra syms" that are not in `invar`)).
+//    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
+//  ------+-------+---------+-----+---------------+-------+-------------------
+//   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
+//    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
+//  ------+-------+---------+-----+---------------+-------+-------------------
+//   (for each `op` map result r_i: set r_i to corresponding map result,
+//    prove that r_i >= minOpUb via contradiction)
+//    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
+//      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
+//
+FailureOr<AffineValueMap>
+mlir::simplifyConstrainedMinMaxOp(Operation *op,
+                                  FlatAffineValueConstraints constraints) {
+  bool isMin = isa<AffineMinOp>(op);
+  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
+  MLIRContext *ctx = op->getContext();
+  Builder builder(ctx);
+  AffineMap map =
+      isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
+  ValueRange operands = op->getOperands();
+  unsigned numResults = map.getNumResults();
+
+  // Add a few extra dimensions.
+  unsigned dimOp = constraints.appendDimVar();      // `op`
+  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
+  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
+
+  // Add an inequality for each result expr_i of map:
+  // isMin: op <= expr_i, !isMin: op >= expr_i
+  auto boundType = isMin ? IntegerPolyhedron::UB : IntegerPolyhedron::LB;
+  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
+  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
+  if (failed(
+          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
+    return failure();
+
+  // Try to compute a lower/upper bound for op, expressed in terms of the other
+  // `dims` and extra symbols.
+  SmallVector<AffineMap> opLb(1), opUb(1);
+  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
+  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
+  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
+  // a TODO of `getSliceBounds` and not handled here.
+  if (!sliceBound || sliceBound.getNumResults() != 1)
+    return failure(); // No or multiple bounds found.
+  // Recover the inclusive UB in the case of an `affine.min`.
+  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
+
+  // Add an equality: Set dimOpBound to computed bound.
+  // Add back dimension for op. (Was removed by `getSliceBounds`.)
+  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
+  if (failed(constraints.addBound(IntegerPolyhedron::EQ, dimOpBound,
+                                  alignedBoundMap)))
+    return failure();
+
+  // If the constraint system is empty, there is an inconsistency. (E.g., this
+  // can happen if loop lb > ub.)
+  if (constraints.isEmpty())
+    return failure();
+
+  // In the case of `isMin` (`!isMin` is inversed):
+  // Prove that each result of `map` has a lower bound that is equal to (or
+  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
+  // can be replaced with the bound. I.e., prove that for each result
+  // expr_i (represented by dimension r_i):
+  //
+  // r_i >= opBound
+  //
+  // To prove this inequality, add its negation to the constraint set and prove
+  // that the constraint set is empty.
+  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
+    FlatAffineValueConstraints newConstr(constraints);
+
+    // Add an equality: r_i = expr_i
+    // Note: These equalities could have been added earlier and used to express
+    // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
+    // computes minOpUb in terms of r_i dims, which is not desired.
+    if (failed(alignAndAddBound(newConstr, IntegerPolyhedron::EQ, i,
+                                map.getSubMap({i - resultDimStart}), operands)))
+      return failure();
+
+    // If `isMin`:  Add inequality: r_i < opBound
+    //              equiv.: opBound - r_i - 1 >= 0
+    // If `!isMin`: Add inequality: r_i > opBound
+    //              equiv.: -opBound + r_i - 1 >= 0
+    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
+    ineq[dimOpBound] = isMin ? 1 : -1;
+    ineq[i] = isMin ? -1 : 1;
+    ineq[newConstr.getNumCols() - 1] = -1;
+    newConstr.addInequality(ineq);
+    if (!newConstr.isEmpty())
+      return failure();
+  }
+
+  // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
+  AffineMap newMap = alignedBoundMap;
+  SmallVector<Value> newOperands;
+  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
+  // If dims/symbols have known constant values, use those in order to simplify
+  // the affine map further.
+  for (int64_t i = 0, e = constraints.getNumVars(); i < e; ++i) {
+    // Skip unused operands and operands that are already constants.
+    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
+      continue;
+    if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i)) {
+      AffineExpr expr =
+          i < newMap.getNumDims()
+              ? builder.getAffineDimExpr(i)
+              : builder.getAffineSymbolExpr(i - newMap.getNumDims());
+      newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
+                              newMap.getNumDims(), newMap.getNumSymbols());
+    }
+  }
+  mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
+  return AffineValueMap(newMap, newOperands);
+}
index 98223f8..e0f236b 100644 (file)
@@ -152,7 +152,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
 
 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
 /// and scf.parallel loops with a known range.
-template <typename OpTy, bool IsMin>
+template <typename OpTy>
 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
@@ -192,8 +192,7 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
       return failure();
     };
 
-    return scf::canonicalizeMinMaxOpInLoop(
-        rewriter, op, op.getAffineMap(), op.getOperands(), IsMin, loopMatcher);
+    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher);
   }
 };
 
@@ -214,8 +213,8 @@ void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
     RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
   patterns
-      .add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
-           AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
+      .add<AffineOpSCFCanonicalizationPattern<AffineMinOp>,
+           AffineOpSCFCanonicalizationPattern<AffineMaxOp>,
            DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
            DimOfLoopResultFolder<tensor::DimOp>,
            DimOfLoopResultFolder<memref::DimOp>>(ctx);
index 3cd26ec..2a83b6d 100644 (file)
@@ -154,7 +154,6 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
   return success();
 }
 
-template <typename OpTy, bool IsMin>
 static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
                                         ForOp partialIteration,
                                         Value previousUb) {
@@ -164,18 +163,20 @@ static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
          "expected same step in main and partial loop");
   Value step = forOp.getStep();
 
-  forOp.walk([&](OpTy affineOp) {
-    AffineMap map = affineOp.getAffineMap();
-    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
-                                     affineOp.getOperands(), IsMin, mainIv,
-                                     previousUb, step,
+  forOp.walk([&](Operation *affineOp) {
+    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
+      return WalkResult::advance();
+    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb,
+                                     step,
                                      /*insideLoop=*/true);
+    return WalkResult::advance();
   });
-  partialIteration.walk([&](OpTy affineOp) {
-    AffineMap map = affineOp.getAffineMap();
-    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
-                                     affineOp.getOperands(), IsMin, partialIv,
-                                     previousUb, step, /*insideLoop=*/false);
+  partialIteration.walk([&](Operation *affineOp) {
+    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
+      return WalkResult::advance();
+    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
+                                     step, /*insideLoop=*/false);
+    return WalkResult::advance();
   });
 }
 
@@ -188,10 +189,7 @@ LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
     return failure();
 
   // Rewrite affine.min and affine.max ops.
-  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
-      rewriter, forOp, partialIteration, previousUb);
-  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
-      rewriter, forOp, partialIteration, previousUb);
+  rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);
 
   return success();
 }
index 7b416ee..0acb511 100644 (file)
@@ -12,7 +12,9 @@
 
 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 using namespace mlir;
 using namespace presburger;
 
-static void unpackOptionalValues(ArrayRef<Optional<Value>> source,
-                                 SmallVector<Value> &target) {
-  target = llvm::to_vector<4>(llvm::map_range(source, [](Optional<Value> val) {
-    return val.has_value() ? *val : Value();
-  }));
-}
-
-/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
-/// constraints drawn from an affine map. Before adding the constraint, the
-/// dimensions/symbols of the affine map are aligned with `constraints`.
-/// `operands` are the SSA Value operands used with the affine map.
-/// Note: This function adds a new symbol column to the `constraints` for each
-/// dimension/symbol that exists in the affine map but not in `constraints`.
-static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
-                                      IntegerPolyhedron::BoundType type,
-                                      unsigned pos, AffineMap map,
-                                      ValueRange operands) {
-  SmallVector<Value> dims, syms, newSyms;
-  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
-  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
-
-  AffineMap alignedMap =
-      alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
-  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
-    constraints.appendSymbolVar(newSyms[i]);
-  return constraints.addBound(type, pos, alignedMap);
-}
-
-/// Add `val` to each result of `map`.
-static AffineMap addConstToResults(AffineMap map, int64_t val) {
-  SmallVector<AffineExpr> newResults;
-  for (AffineExpr r : map.getResults())
-    newResults.push_back(r + val);
-  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
-                        map.getContext());
-}
-
-FailureOr<AffineApplyOp>
-scf::canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
-                          ValueRange operands, bool isMin,
-                          FlatAffineValueConstraints constraints) {
+static FailureOr<AffineApplyOp>
+canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op,
+                     FlatAffineValueConstraints constraints) {
   RewriterBase::InsertionGuard guard(rewriter);
-  unsigned numResults = map.getNumResults();
-
-  // Add a few extra dimensions.
-  unsigned dimOp = constraints.appendDimVar();      // `op`
-  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
-  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
-
-  // Add an inequality for each result expr_i of map:
-  // isMin: op <= expr_i, !isMin: op >= expr_i
-  auto boundType = isMin ? IntegerPolyhedron::UB : IntegerPolyhedron::LB;
-  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
-  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
-  if (failed(
-          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
-    return failure();
-
-  // Try to compute a lower/upper bound for op, expressed in terms of the other
-  // `dims` and extra symbols.
-  SmallVector<AffineMap> opLb(1), opUb(1);
-  constraints.getSliceBounds(dimOp, 1, rewriter.getContext(), &opLb, &opUb);
-  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
-  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
-  // a TODO of `getSliceBounds` and not handled here.
-  if (!sliceBound || sliceBound.getNumResults() != 1)
-    return failure(); // No or multiple bounds found.
-  // Recover the inclusive UB in the case of an `affine.min`.
-  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;
-
-  // Add an equality: Set dimOpBound to computed bound.
-  // Add back dimension for op. (Was removed by `getSliceBounds`.)
-  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
-  if (failed(constraints.addBound(IntegerPolyhedron::EQ, dimOpBound,
-                                  alignedBoundMap)))
-    return failure();
-
-  // If the constraint system is empty, there is an inconsistency. (E.g., this
-  // can happen if loop lb > ub.)
-  if (constraints.isEmpty())
-    return failure();
-
-  // In the case of `isMin` (`!isMin` is inversed):
-  // Prove that each result of `map` has a lower bound that is equal to (or
-  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
-  // can be replaced with the bound. I.e., prove that for each result
-  // expr_i (represented by dimension r_i):
-  //
-  // r_i >= opBound
-  //
-  // To prove this inequality, add its negation to the constraint set and prove
-  // that the constraint set is empty.
-  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
-    FlatAffineValueConstraints newConstr(constraints);
-
-    // Add an equality: r_i = expr_i
-    // Note: These equalities could have been added earlier and used to express
-    // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
-    // computes minOpUb in terms of r_i dims, which is not desired.
-    if (failed(alignAndAddBound(newConstr, IntegerPolyhedron::EQ, i,
-                                map.getSubMap({i - resultDimStart}), operands)))
-      return failure();
-
-    // If `isMin`:  Add inequality: r_i < opBound
-    //              equiv.: opBound - r_i - 1 >= 0
-    // If `!isMin`: Add inequality: r_i > opBound
-    //              equiv.: -opBound + r_i - 1 >= 0
-    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
-    ineq[dimOpBound] = isMin ? 1 : -1;
-    ineq[i] = isMin ? -1 : 1;
-    ineq[newConstr.getNumCols() - 1] = -1;
-    newConstr.addInequality(ineq);
-    if (!newConstr.isEmpty())
-      return failure();
-  }
-
-  // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
-  AffineMap newMap = alignedBoundMap;
-  SmallVector<Value> newOperands;
-  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
-  // If dims/symbols have known constant values, use those in order to simplify
-  // the affine map further.
-  for (int64_t i = 0, e = constraints.getNumVars(); i < e; ++i) {
-    // Skip unused operands and operands that are already constants.
-    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
-      continue;
-    if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i))
-      newOperands[i] =
-          rewriter.create<arith::ConstantIndexOp>(op->getLoc(), *bound);
-  }
-  mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
   rewriter.setInsertionPoint(op);
-  return rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
+  FailureOr<AffineValueMap> simplified =
+      mlir::simplifyConstrainedMinMaxOp(op, constraints);
+  if (failed(simplified))
+    return failure();
+  return rewriter.replaceOpWithNewOp<AffineApplyOp>(
+      op, simplified->getAffineMap(), simplified->getOperands());
 }
 
 static LogicalResult
@@ -231,14 +111,13 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
 /// Note: Due to limitations of IntegerPolyhedron, only constant step sizes
 /// are currently supported.
 LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
-                                              Operation *op, AffineMap map,
-                                              ValueRange operands, bool isMin,
+                                              Operation *op,
                                               LoopMatcherFn loopMatcher) {
   FlatAffineValueConstraints constraints;
   DenseSet<Value> allIvs;
 
   // Find all iteration variables among `minOp`'s operands add constrain them.
-  for (Value operand : operands) {
+  for (Value operand : op->getOperands()) {
     // Skip duplicate ivs.
     if (llvm::is_contained(allIvs, operand))
       continue;
@@ -256,12 +135,12 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
       return failure();
   }
 
-  return canonicalizeMinMaxOp(rewriter, op, map, operands, isMin, constraints);
+  return canonicalizeMinMaxOp(rewriter, op, constraints);
 }
 
-/// Try to simplify a min/max operation `op` after loop peeling. This function
-/// can simplify min/max operations such as (ub is the previous upper bound of
-/// the unpeeled loop):
+/// Try to simplify the given affine.min/max operation `op` after loop peeling.
+/// This function can simplify min/max operations such as (ub is the previous
+/// upper bound of the unpeeled loop):
 /// ```
 /// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>
 /// %r = affine.min #affine.min #map(%iv)[%step, %ub]
@@ -285,9 +164,8 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter,
 /// affine.min ops inside the partial iteration. For an explanation of the other
 /// parameters, see comment of `canonicalizeMinMaxOpInLoop`.
 LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
-                                         AffineMap map, ValueRange operands,
-                                         bool isMin, Value iv, Value ub,
-                                         Value step, bool insideLoop) {
+                                         Value iv, Value ub, Value step,
+                                         bool insideLoop) {
   FlatAffineValueConstraints constraints;
   constraints.appendDimVar({iv, ub, step});
   if (auto constUb = getConstantIntValue(ub))
@@ -311,5 +189,5 @@ LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op,
     constraints.addInequality({1, -1, 1, -1});
   }
 
-  return canonicalizeMinMaxOp(rewriter, op, map, operands, isMin, constraints);
+  return canonicalizeMinMaxOp(rewriter, op, constraints);
 }
index f55537e..e64cbb1 100644 (file)
@@ -2641,6 +2641,7 @@ cc_library(
         ":AffineDialect",
         ":Analysis",
         ":ArithDialect",
+        ":DialectUtils",
         ":FuncDialect",
         ":IR",
         ":SideEffectInterfaces",