From 93d640f3922b2a15501101b229f8be40e8528a63 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 22 Feb 2023 10:33:09 +0100 Subject: [PATCH] [mlir][SCF][Utils][NFC] Make some utils public for better reuse These functions will be used in a subsequent change. Also some minor refactoring. Differential Revision: https://reviews.llvm.org/D143909 --- .../SCF/Utils/AffineCanonicalizationUtils.h | 10 +++ .../SCF/Transforms/LoopCanonicalization.cpp | 35 +-------- .../SCF/Utils/AffineCanonicalizationUtils.cpp | 84 +++++++++++++++------- 3 files changed, 68 insertions(+), 61 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h index fea8704..88c93db 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h @@ -39,6 +39,16 @@ class IfOp; using LoopMatcherFn = function_ref; +/// Match "for loop"-like operations from the SCF dialect. +LogicalResult matchForLikeLoop(Value iv, OpFoldResult &lb, OpFoldResult &ub, + OpFoldResult &step); + +/// Populate the given constraint set with induction variable constraints of a +/// "for" loop with the given range and step. +LogicalResult addLoopRangeConstraints(FlatAffineValueConstraints &cstr, + Value iv, OpFoldResult lb, + OpFoldResult ub, OpFoldResult step); + /// Try to canonicalize the given affine.min/max operation in the context of /// for `loops` with a known range. /// diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index aee1063..8cbca1b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -158,40 +158,7 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub, - OpFoldResult &step) { - if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { - lb = forOp.getLowerBound(); - ub = forOp.getUpperBound(); - step = forOp.getStep(); - return success(); - } - if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { - for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { - if (parOp.getInductionVars()[idx] == iv) { - lb = parOp.getLowerBound()[idx]; - ub = parOp.getUpperBound()[idx]; - step = parOp.getStep()[idx]; - return success(); - } - } - return failure(); - } - if (scf::ForallOp forallOp = scf::getForallOpThreadIndexOwner(iv)) { - for (int64_t idx = 0; idx < forallOp.getRank(); ++idx) { - if (forallOp.getInductionVar(idx) == iv) { - lb = forallOp.getMixedLowerBound()[idx]; - ub = forallOp.getMixedUpperBound()[idx]; - step = forallOp.getMixedStep()[idx]; - return success(); - } - } - return failure(); - } - return failure(); - }; - - return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher); + return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop); } }; diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp index 4ee27e4..7799fa9 100644 --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -12,12 +12,12 @@ #include -#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" @@ -29,6 +29,39 @@ using namespace mlir; using namespace presburger; +LogicalResult scf::matchForLikeLoop(Value iv, OpFoldResult &lb, + OpFoldResult &ub, OpFoldResult &step) { + if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { + lb = forOp.getLowerBound(); + ub = forOp.getUpperBound(); + step = forOp.getStep(); + return success(); + } + if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { + for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { + if (parOp.getInductionVars()[idx] == iv) { + lb = parOp.getLowerBound()[idx]; + ub = parOp.getUpperBound()[idx]; + step = parOp.getStep()[idx]; + return success(); + } + } + return failure(); + } + if (scf::ForallOp forallOp = scf::getForallOpThreadIndexOwner(iv)) { + for (int64_t idx = 0; idx < forallOp.getRank(); ++idx) { + if (forallOp.getInductionVar(idx) == iv) { + lb = forallOp.getMixedLowerBound()[idx]; + ub = forallOp.getMixedUpperBound()[idx]; + step = forallOp.getMixedStep()[idx]; + return success(); + } + } + return failure(); + } + return failure(); +} + static FailureOr canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, FlatAffineValueConstraints constraints) { @@ -42,37 +75,38 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, op, simplified->getAffineMap(), simplified->getOperands()); } -static LogicalResult -addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv, - OpFoldResult lb, OpFoldResult ub, OpFoldResult step, - RewriterBase &rewriter) { +LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr, + Value iv, OpFoldResult lb, + OpFoldResult ub, OpFoldResult step) { + Builder b(iv.getContext()); + // IntegerPolyhedron does not support semi-affine expressions. // Therefore, only constant step values are supported. auto stepInt = getConstantIntValue(step); if (!stepInt) return failure(); - unsigned dimIv = constraints.appendDimVar(iv); + unsigned dimIv = cstr.appendDimVar(iv); auto lbv = lb.dyn_cast(); - unsigned symLb = lbv ? constraints.appendSymbolVar(lbv) - : constraints.appendSymbolVar(/*num=*/1); + unsigned symLb = + lbv ? cstr.appendSymbolVar(lbv) : cstr.appendSymbolVar(/*num=*/1); auto ubv = ub.dyn_cast(); - unsigned symUb = ubv ? constraints.appendSymbolVar(ubv) - : constraints.appendSymbolVar(/*num=*/1); + unsigned symUb = + ubv ? cstr.appendSymbolVar(ubv) : cstr.appendSymbolVar(/*num=*/1); // If loop lower/upper bounds are constant: Add EQ constraint. std::optional lbInt = getConstantIntValue(lb); std::optional ubInt = getConstantIntValue(ub); if (lbInt) - constraints.addBound(IntegerPolyhedron::EQ, symLb, *lbInt); + cstr.addBound(IntegerPolyhedron::EQ, symLb, *lbInt); if (ubInt) - constraints.addBound(IntegerPolyhedron::EQ, symUb, *ubInt); + cstr.addBound(IntegerPolyhedron::EQ, symUb, *ubInt); // Lower bound: iv >= lb (equiv.: iv - lb >= 0) - SmallVector ineqLb(constraints.getNumCols(), 0); + SmallVector ineqLb(cstr.getNumCols(), 0); ineqLb[dimIv] = 1; ineqLb[symLb] = -1; - constraints.addInequality(ineqLb); + cstr.addInequality(ineqLb); // Upper bound AffineExpr ivUb; @@ -81,26 +115,23 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv, // iv < lb + 1 // TODO: Try to derive this constraint by simplifying the expression in // the else-branch. - ivUb = - rewriter.getAffineSymbolExpr(symLb - constraints.getNumDimVars()) + 1; + ivUb = b.getAffineSymbolExpr(symLb - cstr.getNumDimVars()) + 1; } else { // The loop may have more than one iteration. // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 AffineExpr exprLb = - lbInt - ? rewriter.getAffineConstantExpr(*lbInt) - : rewriter.getAffineSymbolExpr(symLb - constraints.getNumDimVars()); + lbInt ? b.getAffineConstantExpr(*lbInt) + : b.getAffineSymbolExpr(symLb - cstr.getNumDimVars()); AffineExpr exprUb = - ubInt - ? rewriter.getAffineConstantExpr(*ubInt) - : rewriter.getAffineSymbolExpr(symUb - constraints.getNumDimVars()); + ubInt ? b.getAffineConstantExpr(*ubInt) + : b.getAffineSymbolExpr(symUb - cstr.getNumDimVars()); ivUb = exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt))); } auto map = AffineMap::get( - /*dimCount=*/constraints.getNumDimVars(), - /*symbolCount=*/constraints.getNumSymbolVars(), /*result=*/ivUb); + /*dimCount=*/cstr.getNumDimVars(), + /*symbolCount=*/cstr.getNumSymbolVars(), /*result=*/ivUb); - return constraints.addBound(IntegerPolyhedron::UB, dimIv, map); + return cstr.addBound(IntegerPolyhedron::UB, dimIv, map); } /// Canonicalize min/max operations in the context of for loops with a known @@ -132,8 +163,7 @@ LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, continue; allIvs.insert(iv); - if (failed( - addLoopRangeConstraints(constraints, iv, lb, ub, step, rewriter))) + if (failed(addLoopRangeConstraints(constraints, iv, lb, ub, step))) return failure(); } -- 2.7.4