[LAA] Move CheckingPtrGroup/PointerCheck outside class (NFC).
authorFlorian Hahn <flo@fhahn.com>
Tue, 28 Apr 2020 20:35:02 +0000 (21:35 +0100)
committerFlorian Hahn <flo@fhahn.com>
Tue, 28 Apr 2020 20:47:31 +0000 (21:47 +0100)
This allows forward declarations of PointerCheck, which in turn reduce
the number of times LoopAccessAnalysis needs to be included.

Ultimately this helps with moving runtime check generation to
Transforms/Utils/LoopUtils.h, without having to include it there.

Reviewers: anemet, Ayal

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D78458

llvm/include/llvm/Analysis/LoopAccessAnalysis.h
llvm/include/llvm/Transforms/Utils/LoopVersioning.h
llvm/lib/Analysis/LoopAccessAnalysis.cpp
llvm/lib/Transforms/Scalar/LoopDistribute.cpp
llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
llvm/lib/Transforms/Utils/LoopUtils.cpp
llvm/lib/Transforms/Utils/LoopVersioning.cpp

index dc950a9..26ddf92 100644 (file)
@@ -324,9 +324,45 @@ private:
   void mergeInStatus(VectorizationSafetyStatus S);
 };
 
+class RuntimePointerChecking;
+/// A grouping of pointers. A single memcheck is required between
+/// two groups.
+struct RuntimeCheckingPtrGroup {
+  /// Create a new pointer checking group containing a single
+  /// pointer, with index \p Index in RtCheck.
+  RuntimeCheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck);
+
+  /// Tries to add the pointer recorded in RtCheck at index
+  /// \p Index to this pointer checking group. We can only add a pointer
+  /// to a checking group if we will still be able to get
+  /// the upper and lower bounds of the check. Returns true in case
+  /// of success, false otherwise.
+  bool addPointer(unsigned Index);
+
+  /// Constitutes the context of this pointer checking group. For each
+  /// pointer that is a member of this group we will retain the index
+  /// at which it appears in RtCheck.
+  RuntimePointerChecking &RtCheck;
+  /// The SCEV expression which represents the upper bound of all the
+  /// pointers in this group.
+  const SCEV *High;
+  /// The SCEV expression which represents the lower bound of all the
+  /// pointers in this group.
+  const SCEV *Low;
+  /// Indices of all the pointers that constitute this grouping.
+  SmallVector<unsigned, 2> Members;
+};
+
+/// A memcheck which made up of a pair of grouped pointers.
+typedef std::pair<const RuntimeCheckingPtrGroup *,
+                  const RuntimeCheckingPtrGroup *>
+    RuntimePointerCheck;
+
 /// Holds information about the memory runtime legality checks to verify
 /// that a group of pointers do not overlap.
 class RuntimePointerChecking {
+  friend struct RuntimeCheckingPtrGroup;
+
 public:
   struct PointerInfo {
     /// Holds the pointer value that we need to check.
@@ -376,59 +412,20 @@ public:
   /// No run-time memory checking is necessary.
   bool empty() const { return Pointers.empty(); }
 
-  /// A grouping of pointers. A single memcheck is required between
-  /// two groups.
-  struct CheckingPtrGroup {
-    /// Create a new pointer checking group containing a single
-    /// pointer, with index \p Index in RtCheck.
-    CheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck)
-        : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End),
-          Low(RtCheck.Pointers[Index].Start) {
-      Members.push_back(Index);
-    }
-
-    /// Tries to add the pointer recorded in RtCheck at index
-    /// \p Index to this pointer checking group. We can only add a pointer
-    /// to a checking group if we will still be able to get
-    /// the upper and lower bounds of the check. Returns true in case
-    /// of success, false otherwise.
-    bool addPointer(unsigned Index);
-
-    /// Constitutes the context of this pointer checking group. For each
-    /// pointer that is a member of this group we will retain the index
-    /// at which it appears in RtCheck.
-    RuntimePointerChecking &RtCheck;
-    /// The SCEV expression which represents the upper bound of all the
-    /// pointers in this group.
-    const SCEV *High;
-    /// The SCEV expression which represents the lower bound of all the
-    /// pointers in this group.
-    const SCEV *Low;
-    /// Indices of all the pointers that constitute this grouping.
-    SmallVector<unsigned, 2> Members;
-  };
-
-  /// A memcheck which made up of a pair of grouped pointers.
-  ///
-  /// These *have* to be const for now, since checks are generated from
-  /// CheckingPtrGroups in LAI::addRuntimeChecks which is a const member
-  /// function.  FIXME: once check-generation is moved inside this class (after
-  /// the PtrPartition hack is removed), we could drop const.
-  typedef std::pair<const CheckingPtrGroup *, const CheckingPtrGroup *>
-      PointerCheck;
-
   /// Generate the checks and store it.  This also performs the grouping
   /// of pointers to reduce the number of memchecks necessary.
   void generateChecks(MemoryDepChecker::DepCandidates &DepCands,
                       bool UseDependencies);
 
   /// Returns the checks that generateChecks created.
-  const SmallVector<PointerCheck, 4> &getChecks() const { return Checks; }
+  const SmallVector<RuntimePointerCheck, 4> &getChecks() const {
+    return Checks;
+  }
 
   /// Decide if we need to add a check between two groups of pointers,
   /// according to needsChecking.
-  bool needsChecking(const CheckingPtrGroup &M,
-                     const CheckingPtrGroup &N) const;
+  bool needsChecking(const RuntimeCheckingPtrGroup &M,
+                     const RuntimeCheckingPtrGroup &N) const;
 
   /// Returns the number of run-time checks required according to
   /// needsChecking.
@@ -438,7 +435,8 @@ public:
   void print(raw_ostream &OS, unsigned Depth = 0) const;
 
   /// Print \p Checks.
-  void printChecks(raw_ostream &OS, const SmallVectorImpl<PointerCheck> &Checks,
+  void printChecks(raw_ostream &OS,
+                   const SmallVectorImpl<RuntimePointerCheck> &Checks,
                    unsigned Depth = 0) const;
 
   /// This flag indicates if we need to add the runtime check.
@@ -448,7 +446,7 @@ public:
   SmallVector<PointerInfo, 2> Pointers;
 
   /// Holds a partitioning of pointers into "check groups".
-  SmallVector<CheckingPtrGroup, 2> CheckingGroups;
+  SmallVector<RuntimeCheckingPtrGroup, 2> CheckingGroups;
 
   /// Check if pointers are in the same partition
   ///
@@ -476,15 +474,14 @@ private:
                    bool UseDependencies);
 
   /// Generate the checks and return them.
-  SmallVector<PointerCheck, 4>
-  generateChecks() const;
+  SmallVector<RuntimePointerCheck, 4> generateChecks() const;
 
   /// Holds a pointer to the ScalarEvolution analysis.
   ScalarEvolution *SE;
 
   /// Set of run-time checks required to establish independence of
   /// otherwise may-aliasing pointers in the loop.
-  SmallVector<PointerCheck, 4> Checks;
+  SmallVector<RuntimePointerCheck, 4> Checks;
 };
 
 /// Drive the analysis of memory accesses in the loop
@@ -557,10 +554,9 @@ public:
   /// Returns a pair of instructions where the first element is the first
   /// instruction generated in possibly a sequence of instructions and the
   /// second value is the final comparator value or NULL if no check is needed.
-  std::pair<Instruction *, Instruction *>
-  addRuntimeChecks(Instruction *Loc,
-                   const SmallVectorImpl<RuntimePointerChecking::PointerCheck>
-                       &PointerChecks) const;
+  std::pair<Instruction *, Instruction *> addRuntimeChecks(
+      Instruction *Loc,
+      const SmallVectorImpl<RuntimePointerCheck> &PointerChecks) const;
 
   /// The diagnostics report generated for the analysis.  E.g. why we
   /// couldn't analyze the loop.
index 355c4d7..650d3ab 100644 (file)
@@ -15,7 +15,6 @@
 #ifndef LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H
 #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H
 
-#include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
@@ -26,6 +25,10 @@ class Loop;
 class LoopAccessInfo;
 class LoopInfo;
 class ScalarEvolution;
+struct RuntimeCheckingPtrGroup;
+typedef std::pair<const RuntimeCheckingPtrGroup *,
+                  const RuntimeCheckingPtrGroup *>
+    RuntimePointerCheck;
 
 /// This class emits a version of the loop where run-time checks ensure
 /// that may-alias pointers can't overlap.
@@ -71,8 +74,7 @@ public:
   Loop *getNonVersionedLoop() { return NonVersionedLoop; }
 
   /// Sets the runtime alias checks for versioning the loop.
-  void setAliasChecks(
-      SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks);
+  void setAliasChecks(SmallVector<RuntimePointerCheck, 4> Checks);
 
   /// Sets the runtime SCEV checks for versioning the loop.
   void setSCEVChecks(SCEVUnionPredicate Check);
@@ -122,22 +124,20 @@ private:
   ValueToValueMapTy VMap;
 
   /// The set of alias checks that we are versioning for.
-  SmallVector<RuntimePointerChecking::PointerCheck, 4> AliasChecks;
+  SmallVector<RuntimePointerCheck, 4> AliasChecks;
 
   /// The set of SCEV checks that we are versioning for.
   SCEVUnionPredicate Preds;
 
   /// Maps a pointer to the pointer checking group that the pointer
   /// belongs to.
-  DenseMap<const Value *, const RuntimePointerChecking::CheckingPtrGroup *>
-      PtrToGroup;
+  DenseMap<const Value *, const RuntimeCheckingPtrGroup *> PtrToGroup;
 
   /// The alias scope corresponding to a pointer checking group.
-  DenseMap<const RuntimePointerChecking::CheckingPtrGroup *, MDNode *>
-      GroupToScope;
+  DenseMap<const RuntimeCheckingPtrGroup *, MDNode *> GroupToScope;
 
   /// The list of alias scopes that a pointer checking group can't alias.
-  DenseMap<const RuntimePointerChecking::CheckingPtrGroup *, MDNode *>
+  DenseMap<const RuntimeCheckingPtrGroup *, MDNode *>
       GroupToNonAliasingScopeList;
 
   /// Analyses used.
index 05f6010..cba3558 100644 (file)
@@ -174,6 +174,13 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
   return OrigSCEV;
 }
 
+RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
+    unsigned Index, RuntimePointerChecking &RtCheck)
+    : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End),
+      Low(RtCheck.Pointers[Index].Start) {
+  Members.push_back(Index);
+}
+
 /// Calculate Start and End points of memory access.
 /// Let's assume A is the first access and B is a memory access on N-th loop
 /// iteration. Then B is calculated as:
@@ -231,14 +238,14 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
   Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
 }
 
-SmallVector<RuntimePointerChecking::PointerCheck, 4>
+SmallVector<RuntimePointerCheck, 4>
 RuntimePointerChecking::generateChecks() const {
-  SmallVector<PointerCheck, 4> Checks;
+  SmallVector<RuntimePointerCheck, 4> Checks;
 
   for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
     for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) {
-      const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I];
-      const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J];
+      const RuntimeCheckingPtrGroup &CGI = CheckingGroups[I];
+      const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J];
 
       if (needsChecking(CGI, CGJ))
         Checks.push_back(std::make_pair(&CGI, &CGJ));
@@ -254,8 +261,8 @@ void RuntimePointerChecking::generateChecks(
   Checks = generateChecks();
 }
 
-bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M,
-                                           const CheckingPtrGroup &N) const {
+bool RuntimePointerChecking::needsChecking(
+    const RuntimeCheckingPtrGroup &M, const RuntimeCheckingPtrGroup &N) const {
   for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I)
     for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J)
       if (needsChecking(M.Members[I], N.Members[J]))
@@ -277,7 +284,7 @@ static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J,
   return I;
 }
 
-bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) {
+bool RuntimeCheckingPtrGroup::addPointer(unsigned Index) {
   const SCEV *Start = RtCheck.Pointers[Index].Start;
   const SCEV *End = RtCheck.Pointers[Index].End;
 
@@ -352,7 +359,7 @@ void RuntimePointerChecking::groupChecks(
   // pointers to the same underlying object.
   if (!UseDependencies) {
     for (unsigned I = 0; I < Pointers.size(); ++I)
-      CheckingGroups.push_back(CheckingPtrGroup(I, *this));
+      CheckingGroups.push_back(RuntimeCheckingPtrGroup(I, *this));
     return;
   }
 
@@ -378,7 +385,7 @@ void RuntimePointerChecking::groupChecks(
     MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue,
                                            Pointers[I].IsWritePtr);
 
-    SmallVector<CheckingPtrGroup, 2> Groups;
+    SmallVector<RuntimeCheckingPtrGroup, 2> Groups;
     auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access));
 
     // Because DepCands is constructed by visiting accesses in the order in
@@ -395,7 +402,7 @@ void RuntimePointerChecking::groupChecks(
 
       // Go through all the existing sets and see if we can find one
       // which can include this pointer.
-      for (CheckingPtrGroup &Group : Groups) {
+      for (RuntimeCheckingPtrGroup &Group : Groups) {
         // Don't perform more than a certain amount of comparisons.
         // This should limit the cost of grouping the pointers to something
         // reasonable.  If we do end up hitting this threshold, the algorithm
@@ -415,7 +422,7 @@ void RuntimePointerChecking::groupChecks(
         // We couldn't add this pointer to any existing set or the threshold
         // for the number of comparisons has been reached. Create a new group
         // to hold the current pointer.
-        Groups.push_back(CheckingPtrGroup(Pointer, *this));
+        Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
     }
 
     // We've computed the grouped checks for this partition.
@@ -451,7 +458,7 @@ bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const {
 }
 
 void RuntimePointerChecking::printChecks(
-    raw_ostream &OS, const SmallVectorImpl<PointerCheck> &Checks,
+    raw_ostream &OS, const SmallVectorImpl<RuntimePointerCheck> &Checks,
     unsigned Depth) const {
   unsigned N = 0;
   for (const auto &Check : Checks) {
@@ -2142,10 +2149,10 @@ struct PointerBounds {
 
 /// Expand code for the lower and upper bound of the pointer group \p CG
 /// in \p TheLoop.  \return the values for the bounds.
-static PointerBounds
-expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop,
-             Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE,
-             const RuntimePointerChecking &PtrRtChecking) {
+static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
+                                  Loop *TheLoop, Instruction *Loc,
+                                  SCEVExpander &Exp, ScalarEvolution *SE,
+                                  const RuntimePointerChecking &PtrRtChecking) {
   Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue;
   const SCEV *Sc = SE->getSCEV(Ptr);
 
@@ -2181,17 +2188,17 @@ expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop,
 
 /// Turns a collection of checks into a collection of expanded upper and
 /// lower bounds for both pointers in the check.
-static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(
-    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks,
-    Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp,
-    const RuntimePointerChecking &PtrRtChecking) {
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
+expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
+             Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp,
+             const RuntimePointerChecking &PtrRtChecking) {
   SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
 
   // Here we're relying on the SCEV Expander's cache to only emit code for the
   // same bounds once.
   transform(
       PointerChecks, std::back_inserter(ChecksWithBounds),
-      [&](const RuntimePointerChecking::PointerCheck &Check) {
+      [&](const RuntimePointerCheck &Check) {
         PointerBounds
           First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking),
           Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking);
@@ -2203,8 +2210,7 @@ static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(
 
 std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks(
     Instruction *Loc,
-    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
-    const {
+    const SmallVectorImpl<RuntimePointerCheck> &PointerChecks) const {
   const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
   auto *SE = PSE->getSE();
   SCEVExpander Exp(*SE, DL, "induction");
index 8e04e6e..8f6c1d9 100644 (file)
@@ -903,15 +903,14 @@ private:
   /// \p PtrToPartition contains the partition number for pointers.  Partition
   /// number -1 means that the pointer is used in multiple partitions.  In this
   /// case we can't safely omit the check.
-  SmallVector<RuntimePointerChecking::PointerCheck, 4>
-  includeOnlyCrossPartitionChecks(
-      const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks,
+  SmallVector<RuntimePointerCheck, 4> includeOnlyCrossPartitionChecks(
+      const SmallVectorImpl<RuntimePointerCheck> &AllChecks,
       const SmallVectorImpl<int> &PtrToPartition,
       const RuntimePointerChecking *RtPtrChecking) {
-    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+    SmallVector<RuntimePointerCheck, 4> Checks;
 
     copy_if(AllChecks, std::back_inserter(Checks),
-            [&](const RuntimePointerChecking::PointerCheck &Check) {
+            [&](const RuntimePointerCheck &Check) {
               for (unsigned PtrIdx1 : Check.first->Members)
                 for (unsigned PtrIdx2 : Check.second->Members)
                   // Only include this check if there is a pair of pointers
index 78460bf..c98d652 100644 (file)
@@ -377,7 +377,7 @@ public:
 
   /// Determine the pointer alias checks to prove that there are no
   /// intervening stores.
-  SmallVector<RuntimePointerChecking::PointerCheck, 4> collectMemchecks(
+  SmallVector<RuntimePointerCheck, 4> collectMemchecks(
       const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {
 
     SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath =
@@ -391,10 +391,10 @@ public:
                    std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr));
 
     const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();
-    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks;
+    SmallVector<RuntimePointerCheck, 4> Checks;
 
     copy_if(AllChecks, std::back_inserter(Checks),
-            [&](const RuntimePointerChecking::PointerCheck &Check) {
+            [&](const RuntimePointerCheck &Check) {
               for (auto PtrIdx1 : Check.first->Members)
                 for (auto PtrIdx2 : Check.second->Members)
                   if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
@@ -520,8 +520,7 @@ public:
 
     // Check intervening may-alias stores.  These need runtime checks for alias
     // disambiguation.
-    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks =
-        collectMemchecks(Candidates);
+    SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);
 
     // Too many checks are likely to outweigh the benefits of forwarding.
     if (Checks.size() > Candidates.size() * CheckPerElim) {
index ab0b5a5..a5fbdb5 100644 (file)
@@ -23,6 +23,7 @@
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/MemorySSA.h"
index 50752bd..2ed54d5 100644 (file)
@@ -45,7 +45,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
 }
 
 void LoopVersioning::setAliasChecks(
-    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) {
+    SmallVector<RuntimePointerCheck, 4> Checks) {
   AliasChecks = std::move(Checks);
 }
 
@@ -194,8 +194,7 @@ void LoopVersioning::prepareNoAliasMetadata() {
 
   // Go through the checks and for each pointer group, collect the scopes for
   // each non-aliasing pointer group.
-  DenseMap<const RuntimePointerChecking::CheckingPtrGroup *,
-           SmallVector<Metadata *, 4>>
+  DenseMap<const RuntimeCheckingPtrGroup *, SmallVector<Metadata *, 4>>
       GroupToNonAliasingScopes;
 
   for (const auto &Check : AliasChecks)