From 616657b39c8122f10519f11d011375be35f6cf2e Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Tue, 28 Apr 2020 21:35:02 +0100 Subject: [PATCH] [LAA] Move CheckingPtrGroup/PointerCheck outside class (NFC). This allows forward declarations of PointerCheck, which in turn reduce the number of times LoopAccessAnalysis needs to be included. Ultimately this helps with moving runtime check generation to Transforms/Utils/LoopUtils.h, without having to include it there. Reviewers: anemet, Ayal Reviewed By: Ayal Differential Revision: https://reviews.llvm.org/D78458 --- llvm/include/llvm/Analysis/LoopAccessAnalysis.h | 102 ++++++++++----------- .../include/llvm/Transforms/Utils/LoopVersioning.h | 18 ++-- llvm/lib/Analysis/LoopAccessAnalysis.cpp | 52 ++++++----- llvm/lib/Transforms/Scalar/LoopDistribute.cpp | 9 +- llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp | 9 +- llvm/lib/Transforms/Utils/LoopUtils.cpp | 1 + llvm/lib/Transforms/Utils/LoopVersioning.cpp | 5 +- 7 files changed, 98 insertions(+), 98 deletions(-) diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h index dc950a9..26ddf92 100644 --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -324,9 +324,45 @@ private: void mergeInStatus(VectorizationSafetyStatus S); }; +class RuntimePointerChecking; +/// A grouping of pointers. A single memcheck is required between +/// two groups. +struct RuntimeCheckingPtrGroup { + /// Create a new pointer checking group containing a single + /// pointer, with index \p Index in RtCheck. + RuntimeCheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck); + + /// Tries to add the pointer recorded in RtCheck at index + /// \p Index to this pointer checking group. We can only add a pointer + /// to a checking group if we will still be able to get + /// the upper and lower bounds of the check. Returns true in case + /// of success, false otherwise. + bool addPointer(unsigned Index); + + /// Constitutes the context of this pointer checking group. For each + /// pointer that is a member of this group we will retain the index + /// at which it appears in RtCheck. + RuntimePointerChecking &RtCheck; + /// The SCEV expression which represents the upper bound of all the + /// pointers in this group. + const SCEV *High; + /// The SCEV expression which represents the lower bound of all the + /// pointers in this group. + const SCEV *Low; + /// Indices of all the pointers that constitute this grouping. + SmallVector Members; +}; + +/// A memcheck which made up of a pair of grouped pointers. +typedef std::pair + RuntimePointerCheck; + /// Holds information about the memory runtime legality checks to verify /// that a group of pointers do not overlap. class RuntimePointerChecking { + friend struct RuntimeCheckingPtrGroup; + public: struct PointerInfo { /// Holds the pointer value that we need to check. @@ -376,59 +412,20 @@ public: /// No run-time memory checking is necessary. bool empty() const { return Pointers.empty(); } - /// A grouping of pointers. A single memcheck is required between - /// two groups. - struct CheckingPtrGroup { - /// Create a new pointer checking group containing a single - /// pointer, with index \p Index in RtCheck. - CheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck) - : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), - Low(RtCheck.Pointers[Index].Start) { - Members.push_back(Index); - } - - /// Tries to add the pointer recorded in RtCheck at index - /// \p Index to this pointer checking group. We can only add a pointer - /// to a checking group if we will still be able to get - /// the upper and lower bounds of the check. Returns true in case - /// of success, false otherwise. - bool addPointer(unsigned Index); - - /// Constitutes the context of this pointer checking group. For each - /// pointer that is a member of this group we will retain the index - /// at which it appears in RtCheck. - RuntimePointerChecking &RtCheck; - /// The SCEV expression which represents the upper bound of all the - /// pointers in this group. - const SCEV *High; - /// The SCEV expression which represents the lower bound of all the - /// pointers in this group. - const SCEV *Low; - /// Indices of all the pointers that constitute this grouping. - SmallVector Members; - }; - - /// A memcheck which made up of a pair of grouped pointers. - /// - /// These *have* to be const for now, since checks are generated from - /// CheckingPtrGroups in LAI::addRuntimeChecks which is a const member - /// function. FIXME: once check-generation is moved inside this class (after - /// the PtrPartition hack is removed), we could drop const. - typedef std::pair - PointerCheck; - /// Generate the checks and store it. This also performs the grouping /// of pointers to reduce the number of memchecks necessary. void generateChecks(MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies); /// Returns the checks that generateChecks created. - const SmallVector &getChecks() const { return Checks; } + const SmallVector &getChecks() const { + return Checks; + } /// Decide if we need to add a check between two groups of pointers, /// according to needsChecking. - bool needsChecking(const CheckingPtrGroup &M, - const CheckingPtrGroup &N) const; + bool needsChecking(const RuntimeCheckingPtrGroup &M, + const RuntimeCheckingPtrGroup &N) const; /// Returns the number of run-time checks required according to /// needsChecking. @@ -438,7 +435,8 @@ public: void print(raw_ostream &OS, unsigned Depth = 0) const; /// Print \p Checks. - void printChecks(raw_ostream &OS, const SmallVectorImpl &Checks, + void printChecks(raw_ostream &OS, + const SmallVectorImpl &Checks, unsigned Depth = 0) const; /// This flag indicates if we need to add the runtime check. @@ -448,7 +446,7 @@ public: SmallVector Pointers; /// Holds a partitioning of pointers into "check groups". - SmallVector CheckingGroups; + SmallVector CheckingGroups; /// Check if pointers are in the same partition /// @@ -476,15 +474,14 @@ private: bool UseDependencies); /// Generate the checks and return them. - SmallVector - generateChecks() const; + SmallVector generateChecks() const; /// Holds a pointer to the ScalarEvolution analysis. ScalarEvolution *SE; /// Set of run-time checks required to establish independence of /// otherwise may-aliasing pointers in the loop. - SmallVector Checks; + SmallVector Checks; }; /// Drive the analysis of memory accesses in the loop @@ -557,10 +554,9 @@ public: /// Returns a pair of instructions where the first element is the first /// instruction generated in possibly a sequence of instructions and the /// second value is the final comparator value or NULL if no check is needed. - std::pair - addRuntimeChecks(Instruction *Loc, - const SmallVectorImpl - &PointerChecks) const; + std::pair addRuntimeChecks( + Instruction *Loc, + const SmallVectorImpl &PointerChecks) const; /// The diagnostics report generated for the analysis. E.g. why we /// couldn't analyze the loop. diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h index 355c4d7..650d3ab 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h +++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h @@ -15,7 +15,6 @@ #ifndef LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H -#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -26,6 +25,10 @@ class Loop; class LoopAccessInfo; class LoopInfo; class ScalarEvolution; +struct RuntimeCheckingPtrGroup; +typedef std::pair + RuntimePointerCheck; /// This class emits a version of the loop where run-time checks ensure /// that may-alias pointers can't overlap. @@ -71,8 +74,7 @@ public: Loop *getNonVersionedLoop() { return NonVersionedLoop; } /// Sets the runtime alias checks for versioning the loop. - void setAliasChecks( - SmallVector Checks); + void setAliasChecks(SmallVector Checks); /// Sets the runtime SCEV checks for versioning the loop. void setSCEVChecks(SCEVUnionPredicate Check); @@ -122,22 +124,20 @@ private: ValueToValueMapTy VMap; /// The set of alias checks that we are versioning for. - SmallVector AliasChecks; + SmallVector AliasChecks; /// The set of SCEV checks that we are versioning for. SCEVUnionPredicate Preds; /// Maps a pointer to the pointer checking group that the pointer /// belongs to. - DenseMap - PtrToGroup; + DenseMap PtrToGroup; /// The alias scope corresponding to a pointer checking group. - DenseMap - GroupToScope; + DenseMap GroupToScope; /// The list of alias scopes that a pointer checking group can't alias. - DenseMap + DenseMap GroupToNonAliasingScopeList; /// Analyses used. diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 05f6010..cba3558 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -174,6 +174,13 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE, return OrigSCEV; } +RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( + unsigned Index, RuntimePointerChecking &RtCheck) + : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End), + Low(RtCheck.Pointers[Index].Start) { + Members.push_back(Index); +} + /// Calculate Start and End points of memory access. /// Let's assume A is the first access and B is a memory access on N-th loop /// iteration. Then B is calculated as: @@ -231,14 +238,14 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); } -SmallVector +SmallVector RuntimePointerChecking::generateChecks() const { - SmallVector Checks; + SmallVector Checks; for (unsigned I = 0; I < CheckingGroups.size(); ++I) { for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) { - const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I]; - const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J]; + const RuntimeCheckingPtrGroup &CGI = CheckingGroups[I]; + const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J]; if (needsChecking(CGI, CGJ)) Checks.push_back(std::make_pair(&CGI, &CGJ)); @@ -254,8 +261,8 @@ void RuntimePointerChecking::generateChecks( Checks = generateChecks(); } -bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M, - const CheckingPtrGroup &N) const { +bool RuntimePointerChecking::needsChecking( + const RuntimeCheckingPtrGroup &M, const RuntimeCheckingPtrGroup &N) const { for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I) for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J) if (needsChecking(M.Members[I], N.Members[J])) @@ -277,7 +284,7 @@ static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J, return I; } -bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) { +bool RuntimeCheckingPtrGroup::addPointer(unsigned Index) { const SCEV *Start = RtCheck.Pointers[Index].Start; const SCEV *End = RtCheck.Pointers[Index].End; @@ -352,7 +359,7 @@ void RuntimePointerChecking::groupChecks( // pointers to the same underlying object. if (!UseDependencies) { for (unsigned I = 0; I < Pointers.size(); ++I) - CheckingGroups.push_back(CheckingPtrGroup(I, *this)); + CheckingGroups.push_back(RuntimeCheckingPtrGroup(I, *this)); return; } @@ -378,7 +385,7 @@ void RuntimePointerChecking::groupChecks( MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue, Pointers[I].IsWritePtr); - SmallVector Groups; + SmallVector Groups; auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access)); // Because DepCands is constructed by visiting accesses in the order in @@ -395,7 +402,7 @@ void RuntimePointerChecking::groupChecks( // Go through all the existing sets and see if we can find one // which can include this pointer. - for (CheckingPtrGroup &Group : Groups) { + for (RuntimeCheckingPtrGroup &Group : Groups) { // Don't perform more than a certain amount of comparisons. // This should limit the cost of grouping the pointers to something // reasonable. If we do end up hitting this threshold, the algorithm @@ -415,7 +422,7 @@ void RuntimePointerChecking::groupChecks( // We couldn't add this pointer to any existing set or the threshold // for the number of comparisons has been reached. Create a new group // to hold the current pointer. - Groups.push_back(CheckingPtrGroup(Pointer, *this)); + Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this)); } // We've computed the grouped checks for this partition. @@ -451,7 +458,7 @@ bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const { } void RuntimePointerChecking::printChecks( - raw_ostream &OS, const SmallVectorImpl &Checks, + raw_ostream &OS, const SmallVectorImpl &Checks, unsigned Depth) const { unsigned N = 0; for (const auto &Check : Checks) { @@ -2142,10 +2149,10 @@ struct PointerBounds { /// Expand code for the lower and upper bound of the pointer group \p CG /// in \p TheLoop. \return the values for the bounds. -static PointerBounds -expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop, - Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE, - const RuntimePointerChecking &PtrRtChecking) { +static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, + Loop *TheLoop, Instruction *Loc, + SCEVExpander &Exp, ScalarEvolution *SE, + const RuntimePointerChecking &PtrRtChecking) { Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue; const SCEV *Sc = SE->getSCEV(Ptr); @@ -2181,17 +2188,17 @@ expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop, /// Turns a collection of checks into a collection of expanded upper and /// lower bounds for both pointers in the check. -static SmallVector, 4> expandBounds( - const SmallVectorImpl &PointerChecks, - Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp, - const RuntimePointerChecking &PtrRtChecking) { +static SmallVector, 4> +expandBounds(const SmallVectorImpl &PointerChecks, Loop *L, + Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp, + const RuntimePointerChecking &PtrRtChecking) { SmallVector, 4> ChecksWithBounds; // Here we're relying on the SCEV Expander's cache to only emit code for the // same bounds once. transform( PointerChecks, std::back_inserter(ChecksWithBounds), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking), Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking); @@ -2203,8 +2210,7 @@ static SmallVector, 4> expandBounds( std::pair LoopAccessInfo::addRuntimeChecks( Instruction *Loc, - const SmallVectorImpl &PointerChecks) - const { + const SmallVectorImpl &PointerChecks) const { const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); auto *SE = PSE->getSE(); SCEVExpander Exp(*SE, DL, "induction"); diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 8e04e6e..8f6c1d9 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -903,15 +903,14 @@ private: /// \p PtrToPartition contains the partition number for pointers. Partition /// number -1 means that the pointer is used in multiple partitions. In this /// case we can't safely omit the check. - SmallVector - includeOnlyCrossPartitionChecks( - const SmallVectorImpl &AllChecks, + SmallVector includeOnlyCrossPartitionChecks( + const SmallVectorImpl &AllChecks, const SmallVectorImpl &PtrToPartition, const RuntimePointerChecking *RtPtrChecking) { - SmallVector Checks; + SmallVector Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { for (unsigned PtrIdx1 : Check.first->Members) for (unsigned PtrIdx2 : Check.second->Members) // Only include this check if there is a pair of pointers diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 78460bf..c98d652 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -377,7 +377,7 @@ public: /// Determine the pointer alias checks to prove that there are no /// intervening stores. - SmallVector collectMemchecks( + SmallVector collectMemchecks( const SmallVectorImpl &Candidates) { SmallPtrSet PtrsWrittenOnFwdingPath = @@ -391,10 +391,10 @@ public: std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr)); const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); - SmallVector Checks; + SmallVector Checks; copy_if(AllChecks, std::back_inserter(Checks), - [&](const RuntimePointerChecking::PointerCheck &Check) { + [&](const RuntimePointerCheck &Check) { for (auto PtrIdx1 : Check.first->Members) for (auto PtrIdx2 : Check.second->Members) if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, @@ -520,8 +520,7 @@ public: // Check intervening may-alias stores. These need runtime checks for alias // disambiguation. - SmallVector Checks = - collectMemchecks(Candidates); + SmallVector Checks = collectMemchecks(Candidates); // Too many checks are likely to outweigh the benefits of forwarding. if (Checks.size() > Candidates.size() * CheckPerElim) { diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index ab0b5a5..a5fbdb5 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 50752bd..2ed54d5 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -45,7 +45,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, } void LoopVersioning::setAliasChecks( - SmallVector Checks) { + SmallVector Checks) { AliasChecks = std::move(Checks); } @@ -194,8 +194,7 @@ void LoopVersioning::prepareNoAliasMetadata() { // Go through the checks and for each pointer group, collect the scopes for // each non-aliasing pointer group. - DenseMap> + DenseMap> GroupToNonAliasingScopes; for (const auto &Check : AliasChecks) -- 2.7.4