[SCEV] Track backedge taken count users (NFCI)
authorNikita Popov <nikita.ppv@gmail.com>
Mon, 29 Nov 2021 20:02:37 +0000 (21:02 +0100)
committerNikita Popov <nikita.ppv@gmail.com>
Wed, 1 Dec 2021 09:16:47 +0000 (10:16 +0100)
Track which SCEVs are used as ExactNotTaken counts in
BackedgeTakenInfo structures, so we can directly determine which
loops need to be invalidated, rather than iterating over all BECounts.

This gives a small compile-time improvement on average, but the
motivation here is more to ensure there are no degenerate cases,
if the number of backedge taken counts is large.

Differential Revision: https://reviews.llvm.org/D114784

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp

index 73faa0a..df50611 100644 (file)
@@ -1378,6 +1378,8 @@ private:
   /// includes an exact count and a maximum count.
   ///
   class BackedgeTakenInfo {
+    friend class ScalarEvolution;
+
     /// A list of computable exits and their not-taken counts.  Loops almost
     /// never have more than one computable exit.
     SmallVector<ExitNotTakenInfo, 1> ExitNotTaken;
@@ -1398,9 +1400,6 @@ private:
     /// True iff the backedge is taken either exactly Max or zero times.
     bool MaxOrZero = false;
 
-    /// SCEV expressions used in any of the ExitNotTakenInfo counts.
-    SmallPtrSet<const SCEV *, 4> Operands;
-
     bool isComplete() const { return IsComplete; }
     const SCEV *getConstantMax() const { return ConstantMax; }
 
@@ -1466,10 +1465,6 @@ private:
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
     bool isConstantMaxOrZero(ScalarEvolution *SE) const;
-
-    /// Return true if any backedge taken count expressions refer to the given
-    /// subexpression.
-    bool hasOperand(const SCEV *S) const;
   };
 
   /// Cache the backedge-taken count of the loops for this function as they
@@ -1480,6 +1475,10 @@ private:
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
+  /// Loops whose backedge taken counts directly use this non-constant SCEV.
+  DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
+      BECountUsers;
+
   /// This map contains entries for all of the PHI instructions that we
   /// attempt to compute constant evolutions for.  This allows us to avoid
   /// potentially expensive recomputation of these properties.  An instruction
@@ -1911,6 +1910,9 @@ private:
   bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
                       SCEV::NoWrapFlags &Flags);
 
+  /// Forget predicated/non-predicated backedge taken counts for the given loop.
+  void forgetBackedgeTakenCounts(const Loop *L, bool Predicated);
+
   /// Drop memoized information for all \p SCEVs.
   void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs);
 
index ece1909..7dc7f99 100644 (file)
@@ -7603,6 +7603,7 @@ void ScalarEvolution::forgetAllLoops() {
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
   ValueExprMap.clear();
@@ -7629,8 +7630,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
     auto *CurrL = LoopWorklist.pop_back_val();
 
     // Drop any stored trip count value.
-    BackedgeTakenCounts.erase(CurrL);
-    PredicatedBackedgeTakenCounts.erase(CurrL);
+    forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
+    forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
 
     // Drop information about predicated SCEV rewrites for this loop.
     for (auto I = PredicatedSCEVRewrites.begin();
@@ -7804,10 +7805,6 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
   return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
 }
 
-bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S) const {
-  return Operands.contains(S);
-}
-
 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
     : ExitLimit(E, E, false, None) {
 }
@@ -7848,19 +7845,6 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
     : ExitLimit(E, M, MaxOrZero, None) {
 }
 
-class SCEVRecordOperands {
-  SmallPtrSetImpl<const SCEV *> &Operands;
-
-public:
-  SCEVRecordOperands(SmallPtrSetImpl<const SCEV *> &Operands)
-    : Operands(Operands) {}
-  bool follow(const SCEV *S) {
-    Operands.insert(S);
-    return true;
-  }
-  bool isDone() { return false; }
-};
-
 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
 /// computable exit into a persistent ExitNotTakenInfo array.
 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
@@ -7889,14 +7873,6 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
   assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
           isa<SCEVConstant>(ConstantMax)) &&
          "No point in having a non-constant max backedge taken count!");
-
-  SCEVRecordOperands RecordOperands(Operands);
-  SCEVTraversal<SCEVRecordOperands> ST(RecordOperands);
-  if (!isa<SCEVCouldNotCompute>(ConstantMax))
-    ST.visitAll(ConstantMax);
-  for (auto &ENT : ExitNotTaken)
-    if (!isa<SCEVCouldNotCompute>(ENT.ExactNotTaken))
-      ST.visitAll(ENT.ExactNotTaken);
 }
 
 /// Compute the number of times the backedge of the specified loop will execute.
@@ -7978,6 +7954,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   // The loop backedge will be taken the maximum or zero times if there's
   // a single exit that must be taken the maximum or zero times.
   bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
+
+  // Remember which SCEVs are used in exit limits for invalidation purposes.
+  // We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken
+  // and MaxBECount, which must be SCEVConstant.
+  for (const auto &Pair : ExitCounts)
+    if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
+      BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
   return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
                            MaxBECount, MaxOrZero);
 }
@@ -12466,6 +12449,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
           std::move(Arg.PredicatedBackedgeTakenCounts)),
+      BECountUsers(std::move(Arg.BECountUsers)),
       ConstantEvolutionLoopExitValue(
           std::move(Arg.ConstantEvolutionLoopExitValue)),
       ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
@@ -12882,6 +12866,23 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
+void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
+                                                bool Predicated) {
+  auto &BECounts =
+      Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
+  auto It = BECounts.find(L);
+  if (It != BECounts.end()) {
+    for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
+      if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
+        auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
+        assert(UserIt != BECountUsers.end());
+        UserIt->second.erase({L, Predicated});
+      }
+    }
+    BECounts.erase(It);
+  }
+}
+
 void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
   SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
   SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
@@ -12906,21 +12907,6 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
     else
       ++I;
   }
-
-  auto RemoveSCEVFromBackedgeMap = [&ToForget](
-      DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
-        for (auto I = Map.begin(), E = Map.end(); I != E;) {
-          BackedgeTakenInfo &BEInfo = I->second;
-          if (any_of(ToForget,
-                     [&BEInfo](const SCEV *S) { return BEInfo.hasOperand(S); }))
-            Map.erase(I++);
-          else
-            ++I;
-        }
-  };
-
-  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
-  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
@@ -12958,6 +12944,15 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
       erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
     ValuesAtScopesUsers.erase(ScopeUserIt);
   }
+
+  auto BEUsersIt = BECountUsers.find(S);
+  if (BEUsersIt != BECountUsers.end()) {
+    // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
+    auto Copy = BEUsersIt->second;
+    for (const auto &Pair : Copy)
+      forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
+    BECountUsers.erase(BEUsersIt);
+  }
 }
 
 void
@@ -13144,10 +13139,31 @@ void ScalarEvolution::verify() const {
           is_contained(It->second, std::make_pair(L, ValueAtScope)))
         continue;
       dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
-             << ValueAtScope << " missing in ValuesAtScopes";
+             << ValueAtScope << " missing in ValuesAtScopes\n";
       std::abort();
     }
   }
+
+  // Verify integrity of BECountUsers.
+  auto VerifyBECountUsers = [&](bool Predicated) {
+    auto &BECounts =
+        Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
+    for (const auto &LoopAndBEInfo : BECounts) {
+      for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
+        if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
+          auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
+          if (UserIt != BECountUsers.end() &&
+              UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
+            continue;
+          dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
+                 << *LoopAndBEInfo.first << " missing from BECountUsers\n";
+          std::abort();
+        }
+      }
+    }
+  };
+  VerifyBECountUsers(/* Predicated */ false);
+  VerifyBECountUsers(/* Predicated */ true);
 }
 
 bool ScalarEvolution::invalidate(