[AssumptionCache] caches @llvm.experimental.guard's
authorJoshua Cao <cao.joshua@yahoo.com>
Mon, 23 Jan 2023 07:26:37 +0000 (23:26 -0800)
committerJoshua Cao <cao.joshua@yahoo.com>
Wed, 25 Jan 2023 04:16:46 +0000 (20:16 -0800)
As discussed in https://github.com/llvm/llvm-project/issues/59901

This change is not NFC. There is one SCEV and EarlyCSE test that have an
improved analysis/optimization case. Rest of the tests are not failing.

I've mostly only added cleanup to SCEV since that is where this issue
started. As a follow up, I believe there is more cleanup opportunity in
SCEV and other affected passes.

There could be cases where there are missed registerAssumption of
guards, but this case is not so bad because there will be no
miscompilation. AssumptionCacheTracker should take care of deleted
guards.

Differential Revision: https://reviews.llvm.org/D142330

llvm/include/llvm/Analysis/AssumptionCache.h
llvm/include/llvm/IR/IntrinsicInst.h
llvm/lib/Analysis/AssumeBundleQueries.cpp
llvm/lib/Analysis/AssumptionCache.cpp
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/Transforms/Utils/CodeExtractor.cpp
llvm/lib/Transforms/Utils/InlineFunction.cpp
llvm/test/Analysis/AssumptionCache/basic.ll
llvm/test/Analysis/ScalarEvolution/guards.ll
llvm/test/Transforms/EarlyCSE/guards.ll

index 12dd9b0..838426d 100644 (file)
@@ -26,7 +26,7 @@
 
 namespace llvm {
 
-class AssumeInst;
+class CondGuardInst;
 class Function;
 class raw_ostream;
 class TargetTransformInfo;
@@ -120,15 +120,15 @@ public:
   ///
   /// The call passed in must be an instruction within this function and must
   /// not already be in the cache.
-  void registerAssumption(AssumeInst *CI);
+  void registerAssumption(CondGuardInst *CI);
 
   /// Remove an \@llvm.assume intrinsic from this function's cache if it has
   /// been added to the cache earlier.
-  void unregisterAssumption(AssumeInst *CI);
+  void unregisterAssumption(CondGuardInst *CI);
 
   /// Update the cache of values being affected by this assumption (i.e.
   /// the values about which this assumption provides information).
-  void updateAffectedValues(AssumeInst *CI);
+  void updateAffectedValues(CondGuardInst *CI);
 
   /// Clear the cache of \@llvm.assume intrinsics for a function.
   ///
index df6ce27..5de0107 100644 (file)
@@ -1513,9 +1513,20 @@ public:
   }
 };
 
+/// This represents intrinsics that guard a condition
+class CondGuardInst : public IntrinsicInst {
+public:
+  static bool classof(const IntrinsicInst *I) {
+    return I->getIntrinsicID() == Intrinsic::assume ||
+           I->getIntrinsicID() == Intrinsic::experimental_guard;
+  }
+  static bool classof(const Value *V) {
+    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+  }
+};
 
 /// This represents the llvm.assume intrinsic.
-class AssumeInst : public IntrinsicInst {
+class AssumeInst : public CondGuardInst {
 public:
   static bool classof(const IntrinsicInst *I) {
     return I->getIntrinsicID() == Intrinsic::assume;
index 7440dbd..110cddb 100644 (file)
@@ -162,7 +162,7 @@ llvm::getKnowledgeForValue(const Value *V,
     return RetainedKnowledge::none();
   if (AC) {
     for (AssumptionCache::ResultElem &Elem : AC->assumptionsFor(V)) {
-      auto *II = cast_or_null<AssumeInst>(Elem.Assume);
+      auto *II = dyn_cast_or_null<AssumeInst>(Elem.Assume);
       if (!II || Elem.Index == AssumptionCache::ExprResultIdx)
         continue;
       if (RetainedKnowledge RK = getKnowledgeFromBundle(
index 11796ef..2d648cc 100644 (file)
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file contains a pass that keeps track of @llvm.assume intrinsics in
-// the functions of a module.
+// This file contains a pass that keeps track of @llvm.assume and
+// @llvm.experimental.guard intrinsics in the functions of a module.
 //
 //===----------------------------------------------------------------------===//
 
@@ -140,7 +140,7 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
   }
 }
 
-void AssumptionCache::updateAffectedValues(AssumeInst *CI) {
+void AssumptionCache::updateAffectedValues(CondGuardInst *CI) {
   SmallVector<AssumptionCache::ResultElem, 16> Affected;
   findAffectedValues(CI, TTI, Affected);
 
@@ -153,7 +153,7 @@ void AssumptionCache::updateAffectedValues(AssumeInst *CI) {
   }
 }
 
-void AssumptionCache::unregisterAssumption(AssumeInst *CI) {
+void AssumptionCache::unregisterAssumption(CondGuardInst *CI) {
   SmallVector<AssumptionCache::ResultElem, 16> Affected;
   findAffectedValues(CI, TTI, Affected);
 
@@ -217,7 +217,7 @@ void AssumptionCache::scanFunction() {
   // to this cache.
   for (BasicBlock &B : F)
     for (Instruction &I : B)
-      if (isa<AssumeInst>(&I))
+      if (isa<CondGuardInst>(&I))
         AssumeHandles.push_back({&I, ExprResultIdx});
 
   // Mark the scan as complete.
@@ -225,10 +225,10 @@ void AssumptionCache::scanFunction() {
 
   // Update affected values.
   for (auto &A : AssumeHandles)
-    updateAffectedValues(cast<AssumeInst>(A));
+    updateAffectedValues(cast<CondGuardInst>(A));
 }
 
-void AssumptionCache::registerAssumption(AssumeInst *CI) {
+void AssumptionCache::registerAssumption(CondGuardInst *CI) {
   // If we haven't scanned the function yet, just drop this assumption. It will
   // be found when we scan later.
   if (!Scanned)
@@ -238,9 +238,9 @@ void AssumptionCache::registerAssumption(AssumeInst *CI) {
 
 #ifndef NDEBUG
   assert(CI->getParent() &&
-         "Cannot register @llvm.assume call not in a basic block");
+         "Cannot a register CondGuardInst not in a basic block");
   assert(&F == CI->getParent()->getParent() &&
-         "Cannot register @llvm.assume call not in this function");
+         "Cannot a register CondGuardInst not in this function");
 
   // We expect the number of assumptions to be small, so in an asserts build
   // check that we don't accumulate duplicates and that all assumptions point
@@ -252,8 +252,8 @@ void AssumptionCache::registerAssumption(AssumeInst *CI) {
 
     assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
            "Cached assumption not inside this function!");
-    assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
-           "Cached something other than a call to @llvm.assume!");
+    assert(isa<CondGuardInst>(VH) &&
+           "Cached something other than CondGuardInst!");
     assert(AssumptionSet.insert(VH).second &&
            "Cache contains multiple copies of a call!");
   }
index 97c80a7..8c62fc3 100644 (file)
@@ -1771,8 +1771,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
       // these to compute max backedge taken counts, but can still use
       // these to prove lack of overflow.  Use this fact to avoid
       // doing extra work that may not pay off.
-      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
-          !AC.assumptions().empty()) {
+      if (!isa<SCEVCouldNotCompute>(MaxBECount) || !AC.assumptions().empty()) {
 
         auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
         setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
@@ -5113,8 +5112,7 @@ ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   // these to prove lack of overflow.  Use this fact to avoid
   // doing extra work that may not pay off.
 
-  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
-      AC.assumptions().empty())
+  if (isa<SCEVCouldNotCompute>(MaxBECount) && AC.assumptions().empty())
     return Result;
 
   // If the backedge is guarded by a comparison with the pre-inc  value the
@@ -5167,8 +5165,7 @@ ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
   // these to prove lack of overflow.  Use this fact to avoid
   // doing extra work that may not pay off.
 
-  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
-      AC.assumptions().empty())
+  if (isa<SCEVCouldNotCompute>(MaxBECount) && AC.assumptions().empty())
     return Result;
 
   // If the backedge is guarded by a comparison with the pre-inc  value the
@@ -11356,7 +11353,7 @@ bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
                                         ICmpInst::Predicate Pred,
                                         const SCEV *LHS, const SCEV *RHS) {
   // No need to even try if we know the module has no guards.
-  if (!HasGuards)
+  if (AC.assumptions().empty())
     return false;
 
   return any_of(*BB, [&](const Instruction &I) {
@@ -11566,15 +11563,6 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
       return true;
   }
 
-  // Check conditions due to any @llvm.experimental.guard intrinsics.
-  auto *GuardDecl = F.getParent()->getFunction(
-      Intrinsic::getName(Intrinsic::experimental_guard));
-  if (GuardDecl)
-    for (const auto *GU : GuardDecl->users())
-      if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
-        if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
-          if (ProveViaCond(Guard->getArgOperand(0), false))
-            return true;
   return false;
 }
 
@@ -13447,25 +13435,11 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
                                  LoopInfo &LI)
     : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
       CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
-      LoopDispositions(64), BlockDispositions(64) {
-  // To use guards for proving predicates, we need to scan every instruction in
-  // relevant basic blocks, and not just terminators.  Doing this is a waste of
-  // time if the IR does not actually contain any calls to
-  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
-  //
-  // This pessimizes the case where a pass that preserves ScalarEvolution wants
-  // to _add_ guards to the module when there weren't any before, and wants
-  // ScalarEvolution to optimize based on those guards.  For now we prefer to be
-  // efficient in lieu of being smart in that rather obscure case.
-
-  auto *GuardDecl = F.getParent()->getFunction(
-      Intrinsic::getName(Intrinsic::experimental_guard));
-  HasGuards = GuardDecl && !GuardDecl->use_empty();
-}
+      LoopDispositions(64), BlockDispositions(64) {}
 
 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
-    : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
-      LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
+    : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI),
+      CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
       PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
@@ -15138,16 +15112,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     Terms.emplace_back(AssumeI->getOperand(0), true);
   }
 
-  // Second, collect information from llvm.experimental.guards dominating the loop.
-  auto *GuardDecl = F.getParent()->getFunction(
-      Intrinsic::getName(Intrinsic::experimental_guard));
-  if (GuardDecl)
-    for (const auto *GU : GuardDecl->users())
-      if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
-        if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
-          Terms.emplace_back(Guard->getArgOperand(0), true);
-
-  // Third, collect conditions from dominating branches. Starting at the loop
+  // Second, collect conditions from dominating branches. Starting at the loop
   // predecessor, climb up the predecessor chain, as long as there are
   // predecessors that can be found that have unique successors leading to the
   // original header.
index 260d27b..a13bdad 100644 (file)
@@ -616,7 +616,7 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
   for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
     if (!AssumeVH)
       continue;
-    CallInst *I = cast<CallInst>(AssumeVH);
+    CondGuardInst *I = cast<CondGuardInst>(AssumeVH);
     assert(I->getFunction() == Q.CxtI->getFunction() &&
            "Got assumption for the wrong function!");
 
@@ -624,9 +624,6 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
     // We're running this loop for once for each value queried resulting in a
     // runtime of ~O(#assumes * #values).
 
-    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
-           "must be an assume intrinsic");
-
     Value *RHS;
     CmpInst::Predicate Pred;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@@ -664,7 +661,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
   for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
     if (!AssumeVH)
       continue;
-    CallInst *I = cast<CallInst>(AssumeVH);
+    CondGuardInst *I = cast<CondGuardInst>(AssumeVH);
     assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
            "Got assumption for the wrong function!");
 
@@ -672,9 +669,6 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     // We're running this loop for once for each value queried resulting in a
     // runtime of ~O(#assumes * #values).
 
-    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
-           "must be an assume intrinsic");
-
     Value *Arg = I->getArgOperand(0);
 
     if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
@@ -7446,11 +7440,9 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
     for (auto &AssumeVH : AC->assumptionsFor(V)) {
       if (!AssumeVH)
         continue;
-      CallInst *I = cast<CallInst>(AssumeVH);
+      IntrinsicInst *I = cast<IntrinsicInst>(AssumeVH);
       assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
              "Got assumption for the wrong function!");
-      assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
-             "must be an assume intrinsic");
 
       if (!isValidAssumeForContext(I, CtxI, DT))
         continue;
index 4d5a2ed..c1fe105 100644 (file)
@@ -1663,14 +1663,14 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
     }
   }
 
-  // Remove @llvm.assume calls that will be moved to the new function from the
-  // old function's assumption cache.
+  // Remove CondGuardInsts that will be moved to the new function from the old
+  // function's assumption cache.
   for (BasicBlock *Block : Blocks) {
     for (Instruction &I : llvm::make_early_inc_range(*Block)) {
-      if (auto *AI = dyn_cast<AssumeInst>(&I)) {
+      if (auto *CI = dyn_cast<CondGuardInst>(&I)) {
         if (AC)
-          AC->unregisterAssumption(AI);
-        AI->eraseFromParent();
+          AC->unregisterAssumption(CI);
+        CI->eraseFromParent();
       }
     }
   }
@@ -1864,7 +1864,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
                                           const Function &NewFunc,
                                           AssumptionCache *AC) {
   for (auto AssumeVH : AC->assumptions()) {
-    auto *I = dyn_cast_or_null<CallInst>(AssumeVH);
+    auto *I = dyn_cast_or_null<CondGuardInst>(AssumeVH);
     if (!I)
       continue;
 
@@ -1876,7 +1876,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
     // that were previously in the old function, but that have now been moved
     // to the new function.
     for (auto AffectedValVH : AC->assumptionsFor(I->getOperand(0))) {
-      auto *AffectedCI = dyn_cast_or_null<CallInst>(AffectedValVH);
+      auto *AffectedCI = dyn_cast_or_null<CondGuardInst>(AffectedValVH);
       if (!AffectedCI)
         continue;
       if (AffectedCI->getFunction() != &OldFunc)
index 61fc373..399c9a4 100644 (file)
@@ -2333,7 +2333,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
       for (BasicBlock &NewBlock :
            make_range(FirstNewBlock->getIterator(), Caller->end()))
         for (Instruction &I : NewBlock)
-          if (auto *II = dyn_cast<AssumeInst>(&I))
+          if (auto *II = dyn_cast<CondGuardInst>(&I))
             IFI.GetAssumptionCache(*Caller).registerAssumption(II);
   }
 
index bd4e7b6..75eb8f3 100644 (file)
@@ -3,12 +3,15 @@
 target datalayout = "e-i64:64-f80:128-n8:16:32:64-S128"
 
 declare void @llvm.assume(i1)
+declare void @llvm.experimental.guard(i1, ...)
 
 define void @test1(i32 %a) {
 ; CHECK-LABEL: Cached assumptions for function: test1
 ; CHECK-NEXT: icmp ne i32 %{{.*}}, 0
 ; CHECK-NEXT: icmp slt i32 %{{.*}}, 0
 ; CHECK-NEXT: icmp sgt i32 %{{.*}}, 0
+; CHECK-NEXT: icmp ult i32 %{{.*}}, 0
+; CHECK-NEXT: icmp ugt i32 %{{.*}}, 0
 
 entry:
   %cond1 = icmp ne i32 %a, 0
@@ -17,6 +20,10 @@ entry:
   call void @llvm.assume(i1 %cond2)
   %cond3 = icmp sgt i32 %a, 0
   call void @llvm.assume(i1 %cond3)
+  %cond4 = icmp ult i32 %a, 0
+  call void (i1, ...) @llvm.experimental.guard(i1 %cond4) [ "deopt"() ]
+  %cond5 = icmp ugt i32 %a, 0
+  call void (i1, ...) @llvm.experimental.guard(i1 %cond5) [ "deopt"() ]
 
   ret void
 }
index 3922775..62b387b 100644 (file)
@@ -86,7 +86,7 @@ entry:
 loop:
 ; CHECK: loop:
 ; CHECK:  call void (i1, ...) @llvm.experimental.guard(i1 true) [ "deopt"() ]
-; CHECK:  %iv.inc.cmp = icmp slt i32 %iv.inc, %len
+; CHECK:  %iv.inc.cmp = icmp ult i32 %iv.inc, %len
 ; CHECK:  call void (i1, ...) @llvm.experimental.guard(i1 %iv.inc.cmp) [ "deopt"() ]
 ; CHECK: leave:
   %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
@@ -129,7 +129,7 @@ left:
 
 be:
 ; CHECK: be:
-; CHECK-NEXT:  %iv.cmp = icmp slt i32 %iv, %len
+; CHECK-NEXT:  %iv.cmp = icmp ult i32 %iv, %len
 ; CHECK-NEXT:  call void (i1, ...) @llvm.experimental.guard(i1 %iv.cmp) [ "deopt"() ]
 ; CHECK: leave:
 
index e837b77..64b8a1f 100644 (file)
@@ -83,13 +83,10 @@ define i32 @test3.unhandled(i32 %val) {
 ; CHECK-LABEL: @test3.unhandled(
 ; CHECK-NEXT:    [[COND0:%.*]] = icmp slt i32 [[VAL:%.*]], 40
 ; CHECK-NEXT:    call void (i1, ...) @llvm.experimental.guard(i1 [[COND0]]) [ "deopt"() ]
-; CHECK-NEXT:    [[COND1:%.*]] = icmp sge i32 [[VAL]], 40
-; CHECK-NEXT:    call void (i1, ...) @llvm.experimental.guard(i1 [[COND1]]) [ "deopt"() ]
+; CHECK-NEXT:    call void (i1, ...) @llvm.experimental.guard(i1 false) [ "deopt"() ]
 ; CHECK-NEXT:    ret i32 0
 ;
 
-; Demonstrates a case we do not yet handle (it is legal to fold %cond2
-; to false)
   %cond0 = icmp slt i32 %val, 40
   call void(i1,...) @llvm.experimental.guard(i1 %cond0) [ "deopt"() ]
   %cond1 = icmp sge i32 %val, 40