From: Max Kazantsev Date: Thu, 1 Oct 2020 08:58:31 +0000 (+0700) Subject: [SCEV] Prove implicaitons via AddRec start X-Git-Tag: llvmorg-13-init~10416 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48;p=platform%2Fupstream%2Fllvm.git [SCEV] Prove implicaitons via AddRec start If we know that some predicate is true for AddRec and an invariant (w.r.t. this AddRec's loop), this fact is, in particular, true on the first iteration. We can try to prove the facts we need using the start value. The motivating example is proving things like ``` isImpliedCondOperands(>=, X, 0, {X,+,-1}, 0} ``` Differential Revision: https://reviews.llvm.org/D88208 Reviewed By: reames --- diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index febca47..158257a 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1677,23 +1677,30 @@ private: getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) const; /// Test whether the condition described by Pred, LHS, and RHS is true - /// whenever the given FoundCondValue value evaluates to true. + /// whenever the given FoundCondValue value evaluates to true in given + /// Context. If Context is nullptr, then the found predicate is true + /// everywhere. bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - const Value *FoundCondValue, bool Inverse); + const Value *FoundCondValue, bool Inverse, + const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is - /// true. + /// true in given Context. If Context is nullptr, then the found predicate is + /// true everywhere. bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, - const SCEV *FoundRHS); + const SCEV *FoundRHS, + const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is - /// true. + /// true in given Context. If Context is nullptr, then the found predicate is + /// true everywhere. bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS); + const SCEV *FoundRHS, + const Instruction *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is @@ -1744,6 +1751,18 @@ private: /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true. /// + /// This routine tries to weaken the known condition basing on fact that + /// FoundLHS is an AddRec. + bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, + const SCEV *FoundRHS, + const Instruction *Context); + + /// Test whether the condition described by Pred, LHS, and RHS is true + /// whenever the condition described by Pred, FoundLHS, and FoundRHS is + /// true. + /// /// This routine tries to figure out predicate for Phis which are SCEVUnknown /// if it is true for every possible incoming value from their respective /// basic blocks. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e51b316..a3e454f 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9549,15 +9549,16 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, // Try to prove (Pred, LHS, RHS) using isImpliedCond. auto ProveViaCond = [&](const Value *Condition, bool Inverse) { - if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse)) + const Instruction *Context = &BB->front(); + if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context)) return true; if (ProvingStrictComparison) { if (!ProvedNonStrictComparison) - ProvedNonStrictComparison = - isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse); + ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS, + Condition, Inverse, Context); if (!ProvedNonEquality) - ProvedNonEquality = - isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse); + ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, + Condition, Inverse, Context); if (ProvedNonStrictComparison && ProvedNonEquality) return true; } @@ -9623,7 +9624,8 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - const Value *FoundCondValue, bool Inverse) { + const Value *FoundCondValue, bool Inverse, + const Instruction *Context) { if (!PendingLoopPredicates.insert(FoundCondValue).second) return false; @@ -9634,12 +9636,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, if (const BinaryOperator *BO = dyn_cast(FoundCondValue)) { if (BO->getOpcode() == Instruction::And) { if (!Inverse) - return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || - isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); + return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse, + Context) || + isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse, + Context); } else if (BO->getOpcode() == Instruction::Or) { if (Inverse) - return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) || - isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse); + return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse, + Context) || + isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse, + Context); } } @@ -9657,14 +9663,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); - return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS); + return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context); } bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, - const SCEV *FoundRHS) { + const SCEV *FoundLHS, const SCEV *FoundRHS, + const Instruction *Context) { // Balance the types. if (getTypeSizeInBits(LHS->getType()) < getTypeSizeInBits(FoundLHS->getType())) { @@ -9708,16 +9714,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, // Check whether the found predicate is the same as the desired predicate. if (FoundPred == Pred) - return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context); // Check whether swapping the found predicate makes it the same as the // desired predicate. if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { if (isa(RHS)) - return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS); + return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context); else - return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), - RHS, LHS, FoundLHS, FoundRHS); + return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS, + LHS, FoundLHS, FoundRHS, Context); } // Unsigned comparison is the same as signed comparison when both the operands @@ -9725,7 +9731,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, if (CmpInst::isUnsigned(FoundPred) && CmpInst::getSignedPredicate(FoundPred) == Pred && isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) - return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context); // Check if we can make progress by sharpening ranges. if (FoundPred == ICmpInst::ICMP_NE && @@ -9762,8 +9768,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, case ICmpInst::ICMP_UGE: // We know V `Pred` SharperMin. If this implies LHS `Pred` // RHS, we're done. - if (isImpliedCondOperands(Pred, LHS, RHS, V, - getConstant(SharperMin))) + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin), + Context)) return true; LLVM_FALLTHROUGH; @@ -9778,7 +9784,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, // // If V `Pred` Min implies LHS `Pred` RHS, we're done. - if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), + Context)) return true; break; @@ -9786,14 +9793,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, case ICmpInst::ICMP_SLE: case ICmpInst::ICMP_ULE: if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS, - LHS, V, getConstant(SharperMin))) + LHS, V, getConstant(SharperMin), Context)) return true; LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS, - LHS, V, getConstant(Min))) + LHS, V, getConstant(Min), Context)) return true; break; @@ -9807,11 +9814,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) - if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS)) + if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context)) return true; if (Pred == ICmpInst::ICMP_NE) if (!ICmpInst::isTrueWhenEqual(FoundPred)) - if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS)) + if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, + Context)) return true; // Otherwise assume the worst. @@ -9890,6 +9898,44 @@ Optional ScalarEvolution::computeConstantDifference(const SCEV *More, return None; } +bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) { + // Try to recognize the following pattern: + // + // FoundRHS = ... + // ... + // loop: + // FoundLHS = {Start,+,W} + // context_bb: // Basic block from the same loop + // known(Pred, FoundLHS, FoundRHS) + // + // If some predicate is known in the context of a loop, it is also known on + // each iteration of this loop, including the first iteration. Therefore, in + // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to + // prove the original pred using this fact. + if (!Context) + return false; + // Make sure AR varies in the context block. + if (auto *AR = dyn_cast(FoundLHS)) { + if (!AR->getLoop()->contains(Context->getParent())) + return false; + if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop())) + return false; + return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS); + } + + if (auto *AR = dyn_cast(FoundRHS)) { + if (!AR->getLoop()->contains(Context)) + return false; + if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop())) + return false; + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart()); + } + + return false; +} + bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { @@ -10080,13 +10126,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS) { + const SCEV *FoundRHS, + const Instruction *Context) { if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; + if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS, + Context)) + return true; + return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) || // ~x < ~y --> x > y diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index ff33495..e5ffc21 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1251,4 +1251,36 @@ TEST_F(ScalarEvolutionsTest, SCEVgetExitLimitForGuardedLoop) { }); } +TEST_F(ScalarEvolutionsTest, ImpliedViaAddRecStart) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "define void @foo(i32* %p) { " + "entry: " + " %x = load i32, i32* %p, !range !0 " + " br label %loop " + "loop: " + " %iv = phi i32 [ %x, %entry], [%iv.next, %backedge] " + " %ne.check = icmp ne i32 %iv, 0 " + " br i1 %ne.check, label %backedge, label %exit " + "backedge: " + " %iv.next = add i32 %iv, -1 " + " br label %loop " + "exit:" + " ret void " + "} " + "!0 = !{i32 0, i32 2147483647}", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *X = SE.getSCEV(getInstructionByName(F, "x")); + auto *Context = getInstructionByName(F, "iv.next"); + EXPECT_TRUE(SE.isKnownPredicateAt(ICmpInst::ICMP_NE, X, + SE.getZero(X->getType()), Context)); + }); +} + } // end namespace llvm