From 98a137196a4d3f728bddace175528c01fe617fbe Mon Sep 17 00:00:00 2001 From: Silviu Baranga Date: Mon, 8 Jun 2015 10:27:06 +0000 Subject: [PATCH] [LAA] Fix estimation of number of memchecks Summary: We need to add a runtime memcheck for each pair of accesses (x,y) where at least one of x and y is a write. Assuming we have w writes and r reads, currently this number is estimated as being w * (w+r-1). This estimation will count (write,write) pairs twice and will overestimate the number of checks required. This change adds a getNumberOfChecks method to RuntimePointerCheck, which will count the number of runtime checks needed (similar in implementation to needsAnyChecking) and uses it to produce the correct number of runtime checks. Test Plan: llvm test suite spec2k spec2k6 Performance results: no changes observed (not surprising since the formula for 1 writer is basically the same, which would cover most cases - at least with the current check limit). Reviewers: anemet Reviewed By: anemet Subscribers: mzolotukhin, llvm-commits Differential Revision: http://reviews.llvm.org/D10217 llvm-svn: 239295 --- llvm/include/llvm/Analysis/LoopAccessAnalysis.h | 13 ++-- llvm/lib/Analysis/LoopAccessAnalysis.cpp | 72 ++++++++++------------ .../LoopAccessAnalysis/number-of-memchecks.ll | 58 +++++++++++++++++ 3 files changed, 100 insertions(+), 43 deletions(-) create mode 100644 llvm/test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h index 06fb082..7b635a8 100644 --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -345,6 +345,10 @@ public: /// to needsChecking. bool needsAnyChecking(const SmallVectorImpl *PtrPartition) const; + /// \brief Returns the number of run-time checks required according to + /// needsChecking. 
+ unsigned getNumberOfChecks(const SmallVectorImpl *PtrPartition) const; + /// \brief Print the list run-time memory checks necessary. /// /// If \p PtrPartition is set, it contains the partition number for @@ -385,7 +389,10 @@ public: /// \brief Number of memchecks required to prove independence of otherwise /// may-alias pointers. - unsigned getNumRuntimePointerChecks() const { return NumComparisons; } + unsigned getNumRuntimePointerChecks( + const SmallVectorImpl *PtrPartition = nullptr) const { + return PtrRtCheck.getNumberOfChecks(PtrPartition); + } /// Return true if the block BB needs to be predicated in order for the loop /// to be vectorized. @@ -460,10 +467,6 @@ private: /// loop-independent and loop-carried dependences between memory accesses. MemoryDepChecker DepChecker; - /// \brief Number of memchecks required to prove independence of otherwise - /// may-alias pointers - unsigned NumComparisons; - Loop *TheLoop; ScalarEvolution *SE; const DataLayout &DL; diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 18ace6b..c661c7b 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -177,15 +177,21 @@ void LoopAccessInfo::RuntimePointerCheck::print( } } -bool LoopAccessInfo::RuntimePointerCheck::needsAnyChecking( +unsigned LoopAccessInfo::RuntimePointerCheck::getNumberOfChecks( const SmallVectorImpl *PtrPartition) const { unsigned NumPointers = Pointers.size(); + unsigned CheckCount = 0; for (unsigned I = 0; I < NumPointers; ++I) for (unsigned J = I + 1; J < NumPointers; ++J) if (needsChecking(I, J, PtrPartition)) - return true; - return false; + CheckCount++; + return CheckCount; +} + +bool LoopAccessInfo::RuntimePointerCheck::needsAnyChecking( + const SmallVectorImpl *PtrPartition) const { + return getNumberOfChecks(PtrPartition) != 0; } namespace { @@ -220,10 +226,11 @@ public: } /// \brief Check whether we can check the pointers at runtime for - /// 
non-intersection. + /// non-intersection. Returns true when we have 0 pointers + /// (a check on 0 pointers for non-intersection will always return true). bool canCheckPtrAtRT(LoopAccessInfo::RuntimePointerCheck &RtCheck, - unsigned &NumComparisons, ScalarEvolution *SE, - Loop *TheLoop, const ValueToValueMap &Strides, + bool &NeedRTCheck, ScalarEvolution *SE, Loop *TheLoop, + const ValueToValueMap &Strides, bool ShouldCheckStride = false); /// \brief Goes over all memory accesses, checks whether a RT check is needed @@ -290,23 +297,22 @@ static bool hasComputableBounds(ScalarEvolution *SE, } bool AccessAnalysis::canCheckPtrAtRT( - LoopAccessInfo::RuntimePointerCheck &RtCheck, unsigned &NumComparisons, + LoopAccessInfo::RuntimePointerCheck &RtCheck, bool &NeedRTCheck, ScalarEvolution *SE, Loop *TheLoop, const ValueToValueMap &StridesMap, bool ShouldCheckStride) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. bool CanDoRT = true; + NeedRTCheck = false; + if (!IsRTCheckNeeded) return true; + bool IsDepCheckNeeded = isDependencyCheckNeeded(); - NumComparisons = 0; // We assign a consecutive id to access from different alias sets. // Accesses between different groups doesn't need to be checked. unsigned ASId = 1; for (auto &AS : AST) { - unsigned NumReadPtrChecks = 0; - unsigned NumWritePtrChecks = 0; - // We assign consecutive id to access from different dependence sets. // Accesses within the same set don't need a runtime check. unsigned RunningDepId = 1; @@ -317,11 +323,6 @@ bool AccessAnalysis::canCheckPtrAtRT( bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true)); MemAccessInfo Access(Ptr, IsWrite); - if (IsWrite) - ++NumWritePtrChecks; - else - ++NumReadPtrChecks; - if (hasComputableBounds(SE, StridesMap, Ptr) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. 
@@ -349,16 +350,15 @@ bool AccessAnalysis::canCheckPtrAtRT( } } - if (IsDepCheckNeeded && CanDoRT && RunningDepId == 2) - NumComparisons += 0; // Only one dependence set. - else { - NumComparisons += (NumWritePtrChecks * (NumReadPtrChecks + - NumWritePtrChecks - 1)); - } - ++ASId; } + // We need a runtime check if there are any accesses that need checking. + // However, some accesses cannot be checked (for example because we + // can't determine their bounds). In these cases we would need a check + // but wouldn't be able to add it. + NeedRTCheck = !CanDoRT || RtCheck.needsAnyChecking(nullptr); + // If the pointers that we would use for the bounds comparison have different // address spaces, assume the values aren't directly comparable, so we can't // use them for the runtime check. We also have to assume they could @@ -1207,22 +1207,17 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { // Build dependence sets and check whether we need a runtime pointer bounds // check. Accesses.buildDependenceSets(); - bool NeedRTCheck = Accesses.isRTCheckNeeded(); // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. - bool CanDoRT = false; - if (NeedRTCheck) - CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop, - Strides); - - DEBUG(dbgs() << "LAA: We need to do " << NumComparisons << - " pointer comparisons.\n"); + bool NeedRTCheck; + bool CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, + NeedRTCheck, SE, + TheLoop, Strides); - // If we only have one set of dependences to check pointers among we don't - // need a runtime check. - if (NumComparisons == 0 && NeedRTCheck) - NeedRTCheck = false; + DEBUG(dbgs() << "LAA: We need to do " + << PtrRtCheck.getNumberOfChecks(nullptr) + << " pointer comparisons.\n"); // Check that we found the bounds for the pointer. 
if (CanDoRT) @@ -1255,10 +1250,11 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { PtrRtCheck.reset(); PtrRtCheck.Need = true; - CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, + CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NeedRTCheck, SE, TheLoop, Strides, true); + // Check that we found the bounds for the pointer. - if (!CanDoRT && NumComparisons > 0) { + if (NeedRTCheck && !CanDoRT) { emitAnalysis(LoopAccessReport() << "cannot check memory dependencies at runtime"); DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n"); @@ -1403,7 +1399,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, const TargetLibraryInfo *TLI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, const ValueToValueMap &Strides) - : DepChecker(SE, L), NumComparisons(0), TheLoop(L), SE(SE), DL(DL), + : DepChecker(SE, L), TheLoop(L), SE(SE), DL(DL), TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1U), CanVecMem(false), StoreToLoopInvariantAddress(false) { diff --git a/llvm/test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll b/llvm/test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll new file mode 100644 index 0000000..f9871c6 --- /dev/null +++ b/llvm/test/Analysis/LoopAccessAnalysis/number-of-memchecks.ll @@ -0,0 +1,58 @@ +; RUN: opt -loop-accesses -analyze < %s | FileCheck %s + +; 3 reads and 3 writes should need 12 memchecks + +target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" +target triple = "aarch64--linux-gnueabi" + +; CHECK: Memory dependences are safe with run-time checks +; Memory dependecies have labels starting from 0, so in +; order to verify that we have n checks, we look for +; (n-1): and not n:. 
+ +; CHECK: Run-time memory checks: +; CHECK-NEXT: 0: +; CHECK: 11: +; CHECK-NOT: 12: + +define void @testf(i16* %a, + i16* %b, + i16* %c, + i16* %d, + i16* %e, + i16* %f) { +entry: + br label %for.body + +for.body: ; preds = %for.body, %entry + %ind = phi i64 [ 0, %entry ], [ %add, %for.body ] + + %add = add nuw nsw i64 %ind, 1 + + %arrayidxA = getelementptr inbounds i16, i16* %a, i64 %ind + %loadA = load i16, i16* %arrayidxA, align 2 + + %arrayidxB = getelementptr inbounds i16, i16* %b, i64 %ind + %loadB = load i16, i16* %arrayidxB, align 2 + + %arrayidxC = getelementptr inbounds i16, i16* %c, i64 %ind + %loadC = load i16, i16* %arrayidxC, align 2 + + %mul = mul i16 %loadB, %loadA + %mul1 = mul i16 %mul, %loadC + + %arrayidxD = getelementptr inbounds i16, i16* %d, i64 %ind + store i16 %mul1, i16* %arrayidxD, align 2 + + %arrayidxE = getelementptr inbounds i16, i16* %e, i64 %ind + store i16 %mul, i16* %arrayidxE, align 2 + + %arrayidxF = getelementptr inbounds i16, i16* %f, i64 %ind + store i16 %mul1, i16* %arrayidxF, align 2 + + %exitcond = icmp eq i64 %add, 20 + br i1 %exitcond, label %for.end, label %for.body + +for.end: ; preds = %for.body + ret void +} -- 2.7.4