/// in the map. It skips AddRecExpr because we cannot guarantee that the
/// replacement is loop invariant in the loop of the AddRec.
///
-/// At the moment only rewriting SCEVUnknown is supported.
+/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
+/// supported.
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> ⤅
return Expr;
return I->second;
}
+
+ const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+ auto I = Map.find(Expr);
+ if (I == Map.end())
+ return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+ Expr);
+ return I->second;
+ }
};
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+ SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
DenseMap<const SCEV *, const SCEV *>
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
+ ExprsToRewrite.push_back(LHSUnknown);
return;
}
}
// Check for a condition of the form (-C1 + X < C2). InstCombine will
// create this form when combining two checks of the form (X u< C2 + C1) and
// (X >=u C1).
- auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap]() {
+ auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
+ &ExprsToRewrite]() {
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
if (!AddExpr || AddExpr->getNumOperands() != 2)
return false;
RewriteMap[LHSUnknown] = getUMaxExpr(
getConstant(ExactRegion.getUnsignedMin()),
getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
+ ExprsToRewrite.push_back(LHSUnknown);
return true;
};
if (MatchRangeCheckIdiom())
return;
- // For now, limit to conditions that provide information about unknown
- // expressions. RHS also cannot contain add recurrences.
- auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
- if (!LHSUnknown || containsAddRecurrence(RHS))
+ // If RHS is SCEVUnknown, make sure the information is applied to it.
+ if (isa<SCEVUnknown>(RHS)) {
+ std::swap(LHS, RHS);
+ Predicate = CmpInst::getSwappedPredicate(Predicate);
+ }
+ // If LHS is a constant, apply information to the other expression.
+ if (isa<SCEVConstant>(LHS)) {
+ std::swap(LHS, RHS);
+ Predicate = CmpInst::getSwappedPredicate(Predicate);
+ }
+ // Do not apply information for constants or if RHS contains an AddRec.
+ if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+ return;
+
+ // Limit to expressions that can be rewritten.
+ if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
return;
// Check whether LHS has already been rewritten. In that case we want to
// chain further rewrites onto the already rewritten value.
auto I = RewriteMap.find(LHS);
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
+
const SCEV *RewrittenRHS = nullptr;
switch (Predicate) {
case CmpInst::ICMP_ULT:
break;
}
- if (RewrittenRHS)
+ if (RewrittenRHS) {
RewriteMap[LHS] = RewrittenRHS;
+ if (LHS == RewrittenLHS)
+ ExprsToRewrite.push_back(LHS);
+ }
};
// Starting at the loop predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
if (RewriteMap.empty())
return Expr;
+
+ // Now that all rewrite information is collect, rewrite the collected
+ // expressions with the information in the map. This applies information to
+ // sub-expressions.
+ if (ExprsToRewrite.size() > 1) {
+ for (const SCEV *Expr : ExprsToRewrite) {
+ const SCEV *RewriteTo = RewriteMap[Expr];
+ RewriteMap.erase(Expr);
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+ RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
+ }
+ }
+
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
return Rewriter.visit(Expr);
}
define void @rewrite_zext(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 2
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 4611686018427387903
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
define void @rewrite_zext_and_base_1(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
define void @rewrite_zext_and_base_2(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1