[LAA/LV] Simplify stride speculation logic [NFC] (try 2)
authorPhilip Reames <preames@rivosinc.com>
Thu, 11 May 2023 16:47:37 +0000 (09:47 -0700)
committerPhilip Reames <listmail@philipreames.com>
Thu, 11 May 2023 17:19:23 +0000 (10:19 -0700)
The original commit wasn't quite NFC, and this was caught by an arguably overly strong assert.  Specifically, I'd failed to strip off the integer cast off the SCEV before saving it in the map.  The result - other than a failed assert - is that we'd speculate on the casted unknown, not the unknown.  The only case I can think of where that might change behavior would be a sext(i1 load).  I doubt that case is interesting in practice, but it's good to be strictly NFC on this change regardless.

Original commit message follows..

The existing code makes it hard to tell that collectStridedAccess is really about identifying some loop invariant SCEV which is *profitable* to speculate is equal to one. The odd dual usage structure of Value and SCEV confuses this point.

We could choose to loosen the profitability analysis if desired. I'm not proposing doing so at this time as it exposes too many cases where the speculation is unprofitable.

Differential Revision: https://reviews.llvm.org/D147750

llvm/include/llvm/Analysis/LoopAccessAnalysis.h
llvm/include/llvm/Analysis/VectorUtils.h
llvm/lib/Analysis/LoopAccessAnalysis.cpp
llvm/lib/Analysis/VectorUtils.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

index fae22d4..11b4d62 100644 (file)
@@ -184,7 +184,7 @@ public:
   ///
   /// Only checks sets with elements in \p CheckDeps.
   bool areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps,
-                   const ValueToValueMap &Strides);
+                   const DenseMap<Value *, const SCEV *> &Strides);
 
   /// No memory dependence was encountered that would inhibit
   /// vectorization.
@@ -316,7 +316,7 @@ private:
   /// Otherwise, this function returns true signaling a possible dependence.
   Dependence::DepType isDependent(const MemAccessInfo &A, unsigned AIdx,
                                   const MemAccessInfo &B, unsigned BIdx,
-                                  const ValueToValueMap &Strides);
+                                  const DenseMap<Value *, const SCEV *> &Strides);
 
   /// Check whether the data dependence could prevent store-load
   /// forwarding.
@@ -612,7 +612,9 @@ public:
 
   /// If an access has a symbolic strides, this maps the pointer value to
   /// the stride symbol.
-  const ValueToValueMap &getSymbolicStrides() const { return SymbolicStrides; }
+  const DenseMap<Value *, const SCEV *> &getSymbolicStrides() const {
+    return SymbolicStrides;
+  }
 
   /// Pointer has a symbolic stride.
   bool hasStride(Value *V) const { return StrideSet.count(V); }
@@ -699,7 +701,7 @@ private:
 
   /// If an access has a symbolic strides, this maps the pointer value to
   /// the stride symbol.
-  ValueToValueMap SymbolicStrides;
+  DenseMap<Value *, const SCEV *> SymbolicStrides;
 
   /// Set of symbolic strides values.
   SmallPtrSet<Value *, 8> StrideSet;
@@ -716,9 +718,10 @@ Value *stripIntegerCast(Value *V);
 ///
 /// \p PtrToStride provides the mapping between the pointer value and its
 /// stride as collected by LoopVectorizationLegality::collectStridedAccess.
-const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
-                                      const ValueToValueMap &PtrToStride,
-                                      Value *Ptr);
+const SCEV *
+replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
+                          const DenseMap<Value *, const SCEV *> &PtrToStride,
+                          Value *Ptr);
 
 /// If the pointer has a constant stride return it in units of the access type
 /// size.  Otherwise return std::nullopt.
@@ -737,7 +740,7 @@ const SCEV *replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
 std::optional<int64_t>
 getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr,
              const Loop *Lp,
-             const ValueToValueMap &StridesMap = ValueToValueMap(),
+             const DenseMap<Value *, const SCEV *> &StridesMap = DenseMap<Value *, const SCEV *>(),
              bool Assume = false, bool ShouldCheckWrap = true);
 
 /// Returns the distance between the pointers \p PtrA and \p PtrB iff they are
index 5d82454..18214fe 100644 (file)
@@ -917,7 +917,7 @@ private:
   /// Collect all the accesses with a constant stride in program order.
   void collectConstStrideAccesses(
       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
-      const ValueToValueMap &Strides);
+      const DenseMap<Value *, const SCEV *> &Strides);
 
   /// Returns true if \p Stride is allowed in an interleaved group.
   static bool isStrided(int Stride);
index 351e090..df21679 100644 (file)
@@ -154,23 +154,25 @@ Value *llvm::stripIntegerCast(Value *V) {
 }
 
 const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
-                                            const ValueToValueMap &PtrToStride,
+                                            const DenseMap<Value *, const SCEV *> &PtrToStride,
                                             Value *Ptr) {
   const SCEV *OrigSCEV = PSE.getSCEV(Ptr);
 
   // If there is an entry in the map return the SCEV of the pointer with the
   // symbolic stride replaced by one.
-  ValueToValueMap::const_iterator SI = PtrToStride.find(Ptr);
+  DenseMap<Value *, const SCEV *>::const_iterator SI = PtrToStride.find(Ptr);
   if (SI == PtrToStride.end())
     // For a non-symbolic stride, just return the original expression.
     return OrigSCEV;
 
-  Value *StrideVal = stripIntegerCast(SI->second);
-
-  ScalarEvolution *SE = PSE.getSE();
-  const SCEV *StrideSCEV = SE->getSCEV(StrideVal);
+  const SCEV *StrideSCEV = SI->second;
+  // Note: This assert is both overly strong and overly weak.  The actual
+  // invariant here is that StrideSCEV should be loop invariant.  The only
+  // such invariant strides we happen to speculate right now are unknowns
+  // and thus this is a reasonable proxy of the actual invariant.
   assert(isa<SCEVUnknown>(StrideSCEV) && "shouldn't be in map");
 
+  ScalarEvolution *SE = PSE.getSE();
   const auto *CT = SE->getOne(StrideSCEV->getType());
   PSE.addPredicate(*SE->getEqualPredicate(StrideSCEV, CT));
   auto *Expr = PSE.getSCEV(Ptr);
@@ -658,7 +660,7 @@ public:
   /// the bounds of the pointer.
   bool createCheckForAccess(RuntimePointerChecking &RtCheck,
                             MemAccessInfo Access, Type *AccessTy,
-                            const ValueToValueMap &Strides,
+                            const DenseMap<Value *, const SCEV *> &Strides,
                             DenseMap<Value *, unsigned> &DepSetId,
                             Loop *TheLoop, unsigned &RunningDepId,
                             unsigned ASId, bool ShouldCheckStride, bool Assume);
@@ -669,7 +671,7 @@ public:
   /// Returns true if we need no check or if we do and we can generate them
   /// (i.e. the pointers have computable bounds).
   bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE,
-                       Loop *TheLoop, const ValueToValueMap &Strides,
+                       Loop *TheLoop, const DenseMap<Value *, const SCEV *> &Strides,
                        Value *&UncomputablePtr, bool ShouldCheckWrap = false);
 
   /// Goes over all memory accesses, checks whether a RT check is needed
@@ -764,7 +766,7 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
 
 /// Check whether a pointer address cannot wrap.
 static bool isNoWrap(PredicatedScalarEvolution &PSE,
-                     const ValueToValueMap &Strides, Value *Ptr, Type *AccessTy,
+                     const DenseMap<Value *, const SCEV *> &Strides, Value *Ptr, Type *AccessTy,
                      Loop *L) {
   const SCEV *PtrScev = PSE.getSCEV(Ptr);
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
@@ -957,7 +959,7 @@ static void findForkedSCEVs(
 
 static SmallVector<PointerIntPair<const SCEV *, 1, bool>>
 findForkedPointer(PredicatedScalarEvolution &PSE,
-                  const ValueToValueMap &StridesMap, Value *Ptr,
+                  const DenseMap<Value *, const SCEV *> &StridesMap, Value *Ptr,
                   const Loop *L) {
   ScalarEvolution *SE = PSE.getSE();
   assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!");
@@ -982,7 +984,7 @@ findForkedPointer(PredicatedScalarEvolution &PSE,
 
 bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
                                           MemAccessInfo Access, Type *AccessTy,
-                                          const ValueToValueMap &StridesMap,
+                                          const DenseMap<Value *, const SCEV *> &StridesMap,
                                           DenseMap<Value *, unsigned> &DepSetId,
                                           Loop *TheLoop, unsigned &RunningDepId,
                                           unsigned ASId, bool ShouldCheckWrap,
@@ -1043,7 +1045,7 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
 
 bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
                                      ScalarEvolution *SE, Loop *TheLoop,
-                                     const ValueToValueMap &StridesMap,
+                                     const DenseMap<Value *, const SCEV *> &StridesMap,
                                      Value *&UncomputablePtr, bool ShouldCheckWrap) {
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
@@ -1373,7 +1375,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR,
 std::optional<int64_t> llvm::getPtrStride(PredicatedScalarEvolution &PSE,
                                           Type *AccessTy, Value *Ptr,
                                           const Loop *Lp,
-                                          const ValueToValueMap &StridesMap,
+                                          const DenseMap<Value *, const SCEV *> &StridesMap,
                                           bool Assume, bool ShouldCheckWrap) {
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
@@ -1822,7 +1824,7 @@ static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride,
 MemoryDepChecker::Dependence::DepType
 MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
                               const MemAccessInfo &B, unsigned BIdx,
-                              const ValueToValueMap &Strides) {
+                              const DenseMap<Value *, const SCEV *> &Strides) {
   assert (AIdx < BIdx && "Must pass arguments in program order");
 
   auto [APtr, AIsWrite] = A;
@@ -2016,7 +2018,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
 
 bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets,
                                    MemAccessInfoList &CheckDeps,
-                                   const ValueToValueMap &Strides) {
+                                   const DenseMap<Value *, const SCEV *> &Strides) {
 
   MaxSafeDepDistBytes = -1;
   SmallPtrSet<MemAccessInfo, 8> Visited;
@@ -2691,6 +2693,12 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   if (!Ptr)
     return;
 
+  // Note: getStrideFromPointer is a *profitability* heuristic.  We
+  // could broaden the scope of values returned here - to anything
+  // which happens to be loop invariant and contributes to the
+  // computation of an interesting IV - but we chose not to as we
+  // don't have a cost model here, and broadening the scope exposes
+  // far too many unprofitable cases.
   Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
   if (!Stride)
     return;
@@ -2746,7 +2754,10 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   }
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
 
-  SymbolicStrides[Ptr] = Stride;
+  // Strip back off the integer cast, and check that our result is a
+  // SCEVUnknown as we expect.
+  Value *StrideVal = stripIntegerCast(Stride);
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(PSE->getSCEV(StrideVal));
   StrideSet.insert(Stride);
 }
 
index 423ef67..5da21e7 100644 (file)
@@ -1011,7 +1011,7 @@ bool InterleavedAccessInfo::isStrided(int Stride) {
 
 void InterleavedAccessInfo::collectConstStrideAccesses(
     MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
-    const ValueToValueMap &Strides) {
+    const DenseMap<Value*, const SCEV*> &Strides) {
   auto &DL = TheLoop->getHeader()->getModule()->getDataLayout();
 
   // Since it's desired that the load/store instructions be maintained in
@@ -1091,7 +1091,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses(
 void InterleavedAccessInfo::analyzeInterleaving(
                                  bool EnablePredicatedInterleavedMemAccesses) {
   LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n");
-  const ValueToValueMap &Strides = LAI->getSymbolicStrides();
+  const auto &Strides = LAI->getSymbolicStrides();
 
   // Holds all accesses with a constant stride.
   MapVector<Instruction *, StrideDescriptor> AccessStrideInfo;
index 80a3f13..ffefb94 100644 (file)
@@ -3477,7 +3477,7 @@ InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
 
 static bool containsDecreasingPointers(Loop *TheLoop,
                                        PredicatedScalarEvolution *PSE) {
-  const ValueToValueMap &Strides = ValueToValueMap();
+  const auto &Strides = DenseMap<Value *, const SCEV *>();
   for (BasicBlock *BB : TheLoop->blocks()) {
     // Scan the instructions in the block and look for addresses that are
     // consecutive and decreasing.
index 3a868e8..a2b5c04 100644 (file)
@@ -456,8 +456,8 @@ int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy,
   // it's collected.  This happens from canVectorizeWithIfConvert, when the
   // pointer is checked to reference consecutive elements suitable for a
   // masked access.
-  const ValueToValueMap &Strides =
-    LAI ? LAI->getSymbolicStrides() : ValueToValueMap();
+  const auto &Strides =
+    LAI ? LAI->getSymbolicStrides() : DenseMap<Value *, const SCEV *>();
 
   Function *F = TheLoop->getHeader()->getParent();
   bool OptForSize = F->hasOptSize() ||