[LFTR] Fix post-inc pointer IV with truncated exit count (PR41998)
authorNikita Popov <nikita.ppv@gmail.com>
Sat, 29 Jun 2019 09:24:12 +0000 (09:24 +0000)
committerNikita Popov <nikita.ppv@gmail.com>
Sat, 29 Jun 2019 09:24:12 +0000 (09:24 +0000)
Fixes https://bugs.llvm.org/show_bug.cgi?id=41998. Usually when we
have a truncated exit count we'll truncate the IV when comparing
against the limit, in which case exit count overflow in post-inc
form doesn't matter. However, for pointer IVs we don't do that, so
we have to be careful about incrementing the IV in the wide type.

I'm fixing this by removing the IVCount variable (which was
ExitCount or ExitCount+1) and replacing it with a UsePostInc flag,
and then moving the actual limit adjustment to the individual cases
(which are: pointer IV where we add to the wide type, integer IV
where we add to the narrow type, and constant integer IV where we
add to the wide type).

Differential Revision: https://reviews.llvm.org/D63686

llvm-svn: 364709

llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
llvm/test/Transforms/IndVarSimplify/2011-11-01-lftrptr.ll
llvm/test/Transforms/IndVarSimplify/lftr-pr41998.ll

index f3352a3..8b3707d 100644 (file)
@@ -2300,27 +2300,30 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
 
 /// Insert an IR expression which computes the value held by the IV IndVar
 /// (which must be an loop counter w/unit stride) after the backedge of loop L
-/// is taken IVCount times.  
+/// is taken ExitCount times.
 static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB,
-                           const SCEV *IVCount, Loop *L,
+                           const SCEV *ExitCount, bool UsePostInc, Loop *L,
                            SCEVExpander &Rewriter, ScalarEvolution *SE) {
   assert(isLoopCounter(IndVar, L, SE));
   const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar));
   const SCEV *IVInit = AR->getStart();
 
-  // IVInit may be a pointer while IVCount is an integer when FindLoopCounter
-  // finds a valid pointer IV. Sign extend BECount in order to materialize a
+  // IVInit may be a pointer while ExitCount is an integer when FindLoopCounter
+  // finds a valid pointer IV. Sign extend ExitCount in order to materialize a
   // GEP. Avoid running SCEVExpander on a new pointer value, instead reusing
   // the existing GEPs whenever possible.
-  if (IndVar->getType()->isPointerTy() && !IVCount->getType()->isPointerTy()) {
+  if (IndVar->getType()->isPointerTy() &&
+      !ExitCount->getType()->isPointerTy()) {
     // IVOffset will be the new GEP offset that is interpreted by GEP as a
-    // signed value. IVCount on the other hand represents the loop trip count,
+    // signed value. ExitCount on the other hand represents the loop trip count,
     // which is an unsigned value. FindLoopCounter only allows induction
     // variables that have a positive unit stride of one. This means we don't
     // have to handle the case of negative offsets (yet) and just need to zero
-    // extend IVCount.
+    // extend ExitCount.
     Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType());
-    const SCEV *IVOffset = SE->getTruncateOrZeroExtend(IVCount, OfsTy);
+    const SCEV *IVOffset = SE->getTruncateOrZeroExtend(ExitCount, OfsTy);
+    if (UsePostInc)
+      IVOffset = SE->getAddExpr(IVOffset, SE->getOne(OfsTy));
 
     // Expand the code for the iteration count.
     assert(SE->isLoopInvariant(IVOffset, L) &&
@@ -2341,7 +2344,7 @@ static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB,
     return Builder.CreateGEP(GEPBase->getType()->getPointerElementType(),
                              GEPBase, GEPOffset, "lftr.limit");
   } else {
-    // In any other case, convert both IVInit and IVCount to integers before
+    // In any other case, convert both IVInit and ExitCount to integers before
     // comparing. This may result in SCEV expansion of pointers, but in practice
     // SCEV will fold the pointer arithmetic away as such:
     // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc).
@@ -2349,20 +2352,24 @@ static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB,
     // Valid Cases: (1) both integers is most common; (2) both may be pointers
     // for simple memset-style loops.
     //
-    // IVInit integer and IVCount pointer would only occur if a canonical IV
+    // IVInit integer and ExitCount pointer would only occur if a canonical IV
     // were generated on top of case #2, which is not expected.
 
     assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride");
-    // For unit stride, IVCount = Start + BECount with 2's complement overflow.
+    // For unit stride, IVCount = Start + ExitCount with 2's complement
+    // overflow.
     const SCEV *IVInit = AR->getStart();
 
     // For integer IVs, truncate the IV before computing IVInit + BECount.
     if (SE->getTypeSizeInBits(IVInit->getType())
-        > SE->getTypeSizeInBits(IVCount->getType()))
-      IVInit = SE->getTruncateExpr(IVInit, IVCount->getType());
+        > SE->getTypeSizeInBits(ExitCount->getType()))
+      IVInit = SE->getTruncateExpr(IVInit, ExitCount->getType());
+
+    const SCEV *IVLimit = SE->getAddExpr(IVInit, ExitCount);
+
+    if (UsePostInc)
+      IVLimit = SE->getAddExpr(IVLimit, SE->getOne(IVLimit->getType()));
 
-    const SCEV *IVLimit = SE->getAddExpr(IVInit, IVCount);
-    
     // Expand the code for the iteration count.
     BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
     IRBuilder<> Builder(BI);
@@ -2371,8 +2378,8 @@ static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB,
     // Ensure that we generate the same type as IndVar, or a smaller integer
     // type. In the presence of null pointer values, we have an integer type
     // SCEV expression (IVInit) for a pointer type IV value (IndVar).
-    Type *LimitTy = IVCount->getType()->isPointerTy() ?
-      IndVar->getType() : IVCount->getType();
+    Type *LimitTy = ExitCount->getType()->isPointerTy() ?
+      IndVar->getType() : ExitCount->getType();
     return Rewriter.expandCodeFor(IVLimit, LimitTy, BI);
   }
 }
@@ -2391,9 +2398,9 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
   Instruction * const IncVar =
     cast<Instruction>(IndVar->getIncomingValueForBlock(L->getLoopLatch()));
 
-  // Initialize CmpIndVar and IVCount to their preincremented values.
+  // Initialize CmpIndVar to the preincremented IV.
   Value *CmpIndVar = IndVar;
-  const SCEV *IVCount = ExitCount;
+  bool UsePostInc = false;
 
   // If the exiting block is the same as the backedge block, we prefer to
   // compare against the post-incremented value, otherwise we must compare
@@ -2412,15 +2419,9 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
         SafeToPostInc =
           mustExecuteUBIfPoisonOnPathTo(IncVar, ExitingBB->getTerminator(), DT);
     }
+
     if (SafeToPostInc) {
-      // Add one to the "backedge-taken" count to get the trip count.
-      // This addition may overflow, which is valid as long as the comparison
-      // is truncated to ExitCount->getType().
-      IVCount = SE->getAddExpr(ExitCount,
-                               SE->getOne(ExitCount->getType()));
-      // The BackedgeTaken expression contains the number of times that the
-      // backedge branches to the loop header.  This is one less than the
-      // number of times the loop executes, so use the incremented indvar.
+      UsePostInc = true;
       CmpIndVar = IncVar;
     }
   }
@@ -2445,7 +2446,8 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
       BO->setHasNoSignedWrap(AR->hasNoSignedWrap());
   }
 
-  Value *ExitCnt = genLoopLimit(IndVar, ExitingBB, IVCount, L, Rewriter, SE);
+  Value *ExitCnt = genLoopLimit(
+      IndVar, ExitingBB, ExitCount, UsePostInc, L, Rewriter, SE);
   assert(ExitCnt->getType()->isPointerTy() ==
              IndVar->getType()->isPointerTy() &&
          "genLoopLimit missed a cast");
@@ -2466,25 +2468,20 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
     Builder.SetCurrentDebugLocation(Cond->getDebugLoc());
 
   // LFTR can ignore IV overflow and truncate to the width of
-  // BECount. This avoids materializing the add(zext(add)) expression.
+  // ExitCount. This avoids materializing the add(zext(add)) expression.
   unsigned CmpIndVarSize = SE->getTypeSizeInBits(CmpIndVar->getType());
   unsigned ExitCntSize = SE->getTypeSizeInBits(ExitCnt->getType());
   if (CmpIndVarSize > ExitCntSize) {
     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar));
     const SCEV *ARStart = AR->getStart();
     const SCEV *ARStep = AR->getStepRecurrence(*SE);
-    // For constant IVCount, avoid truncation.
-    if (isa<SCEVConstant>(ARStart) && isa<SCEVConstant>(IVCount)) {
+    // For constant ExitCount, avoid truncation.
+    if (isa<SCEVConstant>(ARStart) && isa<SCEVConstant>(ExitCount)) {
       const APInt &Start = cast<SCEVConstant>(ARStart)->getAPInt();
-      APInt Count = cast<SCEVConstant>(IVCount)->getAPInt();
-      // Note that the post-inc value of ExitCount may have overflowed
-      // above such that IVCount is now zero.
-      if (IVCount != ExitCount && Count == 0) {
-        Count = APInt::getMaxValue(Count.getBitWidth()).zext(CmpIndVarSize);
+      APInt Count = cast<SCEVConstant>(ExitCount)->getAPInt();
+      Count = Count.zext(CmpIndVarSize);
+      if (UsePostInc)
         ++Count;
-      }
-      else
-        Count = Count.zext(CmpIndVarSize);
       APInt NewLimit;
       if (cast<SCEVConstant>(ARStep)->getValue()->isNegative())
         NewLimit = Start - Count;
@@ -2529,7 +2526,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
                     << "       op:\t" << (P == ICmpInst::ICMP_NE ? "!=" : "==")
                     << "\n"
                     << "      RHS:\t" << *ExitCnt << "\n"
-                    << "  IVCount:\t" << *IVCount << "\n"
+                    << "ExitCount:\t" << *ExitCount << "\n"
                     << "  was: " << *BI->getCondition() << "\n");
 
   Value *Cond = Builder.CreateICmp(P, CmpIndVar, ExitCnt, "exitcond");
index 49b2437..48c8925 100644 (file)
@@ -146,8 +146,11 @@ define i8 @testnullptrint(i8* %buf, i8* %end) nounwind {
 ; PTR64-NEXT:    [[GUARD:%.*]] = icmp ult i32 0, [[CNT]]
 ; PTR64-NEXT:    br i1 [[GUARD]], label [[PREHEADER:%.*]], label [[EXIT:%.*]]
 ; PTR64:       preheader:
-; PTR64-NEXT:    [[TMP1:%.*]] = zext i32 [[CNT]] to i64
-; PTR64-NEXT:    [[LFTR_LIMIT:%.*]] = getelementptr i8, i8* null, i64 [[TMP1]]
+; PTR64-NEXT:    [[TMP1:%.*]] = add i32 [[EI]], -1
+; PTR64-NEXT:    [[TMP2:%.*]] = sub i32 [[TMP1]], [[BI]]
+; PTR64-NEXT:    [[TMP3:%.*]] = zext i32 [[TMP2]] to i64
+; PTR64-NEXT:    [[TMP4:%.*]] = add nuw nsw i64 [[TMP3]], 1
+; PTR64-NEXT:    [[LFTR_LIMIT:%.*]] = getelementptr i8, i8* null, i64 [[TMP4]]
 ; PTR64-NEXT:    br label [[LOOP:%.*]]
 ; PTR64:       loop:
 ; PTR64-NEXT:    [[P_01_US_US:%.*]] = phi i8* [ null, [[PREHEADER]] ], [ [[GEP:%.*]], [[LOOP]] ]
index 1f04a4b..6ec5a2d 100644 (file)
@@ -42,9 +42,10 @@ define void @test_ptr(i32 %start) {
 ; CHECK-LABEL: @test_ptr(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = trunc i32 [[START:%.*]] to i3
-; CHECK-NEXT:    [[TMP1:%.*]] = sub i3 0, [[TMP0]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i3 -1, [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = zext i3 [[TMP1]] to i64
-; CHECK-NEXT:    [[LFTR_LIMIT:%.*]] = getelementptr i8, i8* getelementptr inbounds ([256 x i8], [256 x i8]* @data, i64 0, i64 0), i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP3:%.*]] = add nuw nsw i64 [[TMP2]], 1
+; CHECK-NEXT:    [[LFTR_LIMIT:%.*]] = getelementptr i8, i8* getelementptr inbounds ([256 x i8], [256 x i8]* @data, i64 0, i64 0), i64 [[TMP3]]
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
 ; CHECK-NEXT:    [[P:%.*]] = phi i8* [ getelementptr inbounds ([256 x i8], [256 x i8]* @data, i64 0, i64 0), [[ENTRY:%.*]] ], [ [[P_INC:%.*]], [[LOOP]] ]