[mlir] NFC - Expose scf::canonicalizeMinMaxOp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 27 Dec 2022 13:47:02 +0000 (05:47 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 27 Dec 2022 13:47:07 +0000 (05:47 -0800)
mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h
mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp

index 462d6b5..4c7dc28 100644 (file)
 #define MLIR_DIALECT_SCF_UTILS_AFFINECANONICALIZATIONUTILS_H_
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 
 namespace mlir {
+class AffineApplyOp;
 class AffineMap;
+class FlatAffineValueConstraints;
 struct LogicalResult;
 class Operation;
 class OpFoldResult;
@@ -51,6 +54,49 @@ LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op,
                                          AffineMap map, ValueRange operands,
                                          bool isMin, LoopMatcherFn loopMatcher);
 
+/// Attempt to canonicalize min/max operations by proving that their value is
+/// bounded by the same lower and upper bound. In such cases, the operation can
+/// be folded away.
+///
+/// Bounds are computed by FlatAffineValueConstraints. Invariants required for
+/// finding/proving bounds should be supplied via `constraints`.
+///
+/// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
+/// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
+///    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
+///    `op` but are not part of `constraints`, are added as extra symbols.
+/// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
+///    * If `isMin`: r_i >= opBound
+///    * If `isMax`: r_i <= opBound
+///    If this is the case, ub(op) == lb(op).
+/// 4. Replace `op` with `opBound`.
+///
+/// In summary, the following constraints are added throughout this function.
+/// Note: `invar` are dimensions added by the caller to express the invariants.
+/// (Showing only the case where `isMin`.)
+///
+///  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
+///  ------+-------+---------+-----+---------------+-------+-------------------
+///   (various eq./ineq. constraining `invar`, added by the caller)
+///    ... |     0 |       0 |   0 |             0 |   ... |               ...
+///  ------+-------+---------+-----+---------------+-------+-------------------
+///   (various ineq. constraining `op` in terms of `op` operands (`invar` and
+///    extra `op` operands "extra syms" that are not in `invar`)).
+///    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
+///  ------+-------+---------+-----+---------------+-------+-------------------
+///   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
+///    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
+///  ------+-------+---------+-----+---------------+-------+-------------------
+///   (for each `op` map result r_i: set r_i to corresponding map result,
+///    prove that r_i >= minOpUb via contradiction)
+///    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
+///      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
+///
+FailureOr<AffineApplyOp>
+canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
+                     ValueRange operands, bool isMin,
+                     FlatAffineValueConstraints constraints);
+
 /// Try to simplify a min/max operation `op` after loop peeling. This function
 /// can simplify min/max operations such as (ub is the previous upper bound of
 /// the unpeeled loop):
index 7b13e3c..7b416ee 100644 (file)
@@ -62,48 +62,10 @@ static AffineMap addConstToResults(AffineMap map, int64_t val) {
                         map.getContext());
 }
 
-/// This function tries to canonicalize min/max operations by proving that their
-/// value is bounded by the same lower and upper bound. In that case, the
-/// operation can be folded away.
-///
-/// Bounds are computed by FlatAffineValueConstraints. Invariants required for
-/// finding/proving bounds should be supplied via `constraints`.
-///
-/// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
-/// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
-///    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
-///    `op` but are not part of `constraints`, are added as extra symbols.
-/// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
-///    * If `isMin`: r_i >= opBound
-///    * If `isMax`: r_i <= opBound
-///    If this is the case, ub(op) == lb(op).
-/// 4. Replace `op` with `opBound`.
-///
-/// In summary, the following constraints are added throughout this function.
-/// Note: `invar` are dimensions added by the caller to express the invariants.
-/// (Showing only the case where `isMin`.)
-///
-///  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (various eq./ineq. constraining `invar`, added by the caller)
-///    ... |     0 |       0 |   0 |             0 |   ... |               ...
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (various ineq. constraining `op` in terms of `op` operands (`invar` and
-///    extra `op` operands "extra syms" that are not in `invar`)).
-///    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
-///    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
-///  ------+-------+---------+-----+---------------+-------+-------------------
-///   (for each `op` map result r_i: set r_i to corresponding map result,
-///    prove that r_i >= minOpUb via contradiction)
-///    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
-///      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
-///
-static LogicalResult
-canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
-                     ValueRange operands, bool isMin,
-                     FlatAffineValueConstraints constraints) {
+FailureOr<AffineApplyOp>
+scf::canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
+                          ValueRange operands, bool isMin,
+                          FlatAffineValueConstraints constraints) {
   RewriterBase::InsertionGuard guard(rewriter);
   unsigned numResults = map.getNumResults();
 
@@ -195,8 +157,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
   }
   mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
   rewriter.setInsertionPoint(op);
-  rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
-  return success();
+  return rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
 }
 
 static LogicalResult