[LV/LAA] Avoid specializing a loop for stride=1 when this predicate implies a
authorDorit Nuzman <dorit.nuzman@intel.com>
Sun, 5 Nov 2017 16:53:15 +0000 (16:53 +0000)
committerDorit Nuzman <dorit.nuzman@intel.com>
Sun, 5 Nov 2017 16:53:15 +0000 (16:53 +0000)
single-iteration loop

This fixes PR34681. Avoid adding the "Stride == 1" predicate when we know that
Stride >= Trip-Count. Such a predicate will effectively optimize a single
or zero iteration loop, as Trip-Count <= Stride == 1.

Differential Revision: https://reviews.llvm.org/D38785

llvm-svn: 317438

llvm/lib/Analysis/LoopAccessAnalysis.cpp
llvm/test/Transforms/LoopVectorize/pr34681.ll [new file with mode: 0644]
llvm/test/Transforms/LoopVectorize/version-mem-access.ll

index 1988965..e141d6c 100644 (file)
@@ -2136,8 +2136,51 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   if (!Stride)
     return;
 
-  DEBUG(dbgs() << "LAA: Found a strided access that we can version");
+  DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
+                  "versioning:");
   DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+
+  // Avoid adding the "Stride == 1" predicate when we know that 
+  // Stride >= Trip-Count. Such a predicate will effectively optimize a single
+  // or zero iteration loop, as Trip-Count <= Stride == 1.
+  // 
+  // TODO: We are currently not making a very informed decision on when it is
+  // beneficial to apply stride versioning. It might make more sense that the
+  // users of this analysis (such as the vectorizer) will trigger it, based on 
+  // their specific cost considerations; For example, in cases where stride 
+  // versioning does  not help resolving memory accesses/dependences, the
+  // vectorizer should evaluate the cost of the runtime test, and the benefit 
+  // of various possible stride specializations, considering the alternatives 
+  // of using gather/scatters (if available). 
+  
+  const SCEV *StrideExpr = PSE->getSCEV(Stride);
+  const SCEV *BETakenCount = PSE->getBackedgeTakenCount();  
+
+  // Match the types so we can compare the stride and the BETakenCount.
+  // The Stride can be positive/negative, so we sign extend Stride; 
+  // The backdgeTakenCount is non-negative, so we zero extend BETakenCount.
+  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+  uint64_t StrideTypeSize = DL.getTypeAllocSize(StrideExpr->getType());
+  uint64_t BETypeSize = DL.getTypeAllocSize(BETakenCount->getType());
+  const SCEV *CastedStride = StrideExpr;
+  const SCEV *CastedBECount = BETakenCount;
+  ScalarEvolution *SE = PSE->getSE();
+  if (BETypeSize >= StrideTypeSize)
+    CastedStride = SE->getNoopOrSignExtend(StrideExpr, BETakenCount->getType());
+  else
+    CastedBECount = SE->getZeroExtendExpr(BETakenCount, StrideExpr->getType());
+  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
+  // Since TripCount == BackEdgeTakenCount + 1, checking:
+  // "Stride >= TripCount" is equivalent to checking: 
+  // Stride - BETakenCount > 0
+  if (SE->isKnownPositive(StrideMinusBETaken)) {
+    DEBUG(dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
+                    "Stride==1 predicate will imply that the loop executes "
+                    "at most once.\n");
+    return;
+  }  
+  DEBUG(dbgs() << "LAA: Found a strided access that we can version.");
+
   SymbolicStrides[Ptr] = Stride;
   StrideSet.insert(Stride);
 }
diff --git a/llvm/test/Transforms/LoopVectorize/pr34681.ll b/llvm/test/Transforms/LoopVectorize/pr34681.ll
new file mode 100644 (file)
index 0000000..e93265e
--- /dev/null
@@ -0,0 +1,122 @@
+; RUN: opt -S -loop-vectorize -force-vector-width=4 -force-vector-interleave=1 < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+; Check the scenario where we have an unknown Stride, which happens to also be
+; the loop iteration count, so if we specialize the loop for the Stride==1 case,
+; this also implies that the loop will iterate no more than a single iteration,
+; as in the following example: 
+;
+;       unsigned int N;
+;       int tmp = 0;
+;       for(unsigned int k=0;k<N;k++) {
+;         tmp+=(int)B[k*N+j];
+;       }
+;
+; We check here that the following runtime scev guard for Stride==1 is NOT generated:
+; vector.scevcheck:
+;   %ident.check = icmp ne i32 %N, 1
+;   %0 = or i1 false, %ident.check
+;   br i1 %0, label %scalar.ph, label %vector.ph
+; Instead the loop is vectorized with an unknown stride.
+
+; CHECK-LABEL: @foo1
+; CHECK: for.body.lr.ph
+; CHECK-NOT: %ident.check = icmp ne i32 %N, 1
+; CHECK-NOT: %[[TEST:[0-9]+]] = or i1 false, %ident.check
+; CHECK-NOT: br i1 %[[TEST]], label %scalar.ph, label %vector.ph
+; CHECK: vector.ph
+; CHECK: vector.body
+; CHECK: <4 x i32>
+; CHECK: middle.block
+; CHECK: scalar.ph
+
+
+define i32 @foo1(i32 %N, i16* nocapture readnone %A, i16* nocapture readonly %B, i32 %i, i32 %j)  {
+entry:
+  %cmp8 = icmp eq i32 %N, 0
+  br i1 %cmp8, label %for.end, label %for.body.lr.ph
+
+for.body.lr.ph:
+  br label %for.body
+
+for.body:
+  %tmp.010 = phi i32 [ 0, %for.body.lr.ph ], [ %add1, %for.body ]
+  %k.09 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %for.body ]
+  %mul = mul i32 %k.09, %N
+  %add = add i32 %mul, %j
+  %arrayidx = getelementptr inbounds i16, i16* %B, i32 %add
+  %0 = load i16, i16* %arrayidx, align 2
+  %conv = sext i16 %0 to i32
+  %add1 = add nsw i32 %tmp.010, %conv
+  %inc = add nuw i32 %k.09, 1
+  %exitcond = icmp eq i32 %inc, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  %add1.lcssa = phi i32 [ %add1, %for.body ]
+  br label %for.end
+
+for.end: 
+  %tmp.0.lcssa = phi i32 [ 0, %entry ], [ %add1.lcssa, %for.end.loopexit ]
+  ret i32 %tmp.0.lcssa
+}
+
+
+; Check the same, but also where the Stride and the loop iteration count
+; are not of the same data type. 
+;
+;       unsigned short N;
+;       int tmp = 0;
+;       for(unsigned int k=0;k<N;k++) {
+;         tmp+=(int)B[k*N+j];
+;       }
+;
+; We check here that the following runtime scev guard for Stride==1 is NOT generated:
+; vector.scevcheck:
+; %ident.check = icmp ne i16 %N, 1
+; %0 = or i1 false, %ident.check
+; br i1 %0, label %scalar.ph, label %vector.ph
+
+
+; CHECK-LABEL: @foo2
+; CHECK: for.body.lr.ph
+; CHECK-NOT: %ident.check = icmp ne i16 %N, 1
+; CHECK-NOT: %[[TEST:[0-9]+]] = or i1 false, %ident.check
+; CHECK-NOT: br i1 %[[TEST]], label %scalar.ph, label %vector.ph
+; CHECK: vector.ph
+; CHECK: vector.body
+; CHECK: <4 x i32>
+; CHECK: middle.block
+; CHECK: scalar.ph
+
+define i32 @foo2(i16 zeroext %N, i16* nocapture readnone %A, i16* nocapture readonly %B, i32 %i, i32 %j) {
+entry:
+  %conv = zext i16 %N to i32
+  %cmp11 = icmp eq i16 %N, 0
+  br i1 %cmp11, label %for.end, label %for.body.lr.ph
+
+for.body.lr.ph:
+  br label %for.body
+
+for.body:
+  %tmp.013 = phi i32 [ 0, %for.body.lr.ph ], [ %add4, %for.body ]
+  %k.012 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %for.body ]
+  %mul = mul nuw i32 %k.012, %conv
+  %add = add i32 %mul, %j
+  %arrayidx = getelementptr inbounds i16, i16* %B, i32 %add
+  %0 = load i16, i16* %arrayidx, align 2
+  %conv3 = sext i16 %0 to i32
+  %add4 = add nsw i32 %tmp.013, %conv3
+  %inc = add nuw nsw i32 %k.012, 1
+  %exitcond = icmp eq i32 %inc, %conv
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  %add4.lcssa = phi i32 [ %add4, %for.body ]
+  br label %for.end
+
+for.end:
+  %tmp.0.lcssa = phi i32 [ 0, %entry ], [ %add4.lcssa, %for.end.loopexit ]
+  ret i32 %tmp.0.lcssa
+}
index a9d319e..774b6f2 100644 (file)
@@ -65,7 +65,8 @@ for.end:
 define void @fn1(double* noalias %x, double* noalias %c, double %a) {
 entry:
   %conv = fptosi double %a to i32
-  %cmp8 = icmp sgt i32 %conv, 0
+  %conv2 = add i32 %conv, 4
+  %cmp8 = icmp sgt i32 %conv2, 0
   br i1 %cmp8, label %for.body.preheader, label %for.end
 
 for.body.preheader:
@@ -82,7 +83,7 @@ for.body:
   store double %1, double* %arrayidx3, align 8
   %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
   %lftr.wideiv = trunc i64 %indvars.iv.next to i32
-  %exitcond = icmp eq i32 %lftr.wideiv, %conv
+  %exitcond = icmp eq i32 %lftr.wideiv, %conv2
   br i1 %exitcond, label %for.end.loopexit, label %for.body
 
 for.end.loopexit: