From 46abd1fbe88fe1f4b0e6cb2b87f3e7d148bbadf7 Mon Sep 17 00:00:00 2001 From: Rosie Sumpter Date: Mon, 9 Aug 2021 12:51:17 +0100 Subject: [PATCH] [LoopFlatten] Fix assertion failure in checkOverflow There is an assertion failure in computeOverflowForUnsignedMul (used in checkOverflow) due to the inner and outer trip counts having different types. This occurs when the IV has been widened, but the loop components are not successfully rediscovered. This is fixed by some refactoring of the code in findLoopComponents which identifies the trip count of the loop. --- llvm/lib/Transforms/Scalar/LoopFlatten.cpp | 93 ++++++++++++++++++---------- llvm/test/Transforms/LoopFlatten/widen-iv.ll | 46 ++++++++++++++ 2 files changed, 105 insertions(+), 34 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 61f6d21..3343bdd 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -93,6 +93,17 @@ struct FlattenInfo { FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; }; +static bool +setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, + SmallPtrSetImpl &IterationInstructions) { + TripCount = TC; + IterationInstructions.insert(Increment); + LLVM_DEBUG(dbgs() << "Found Increment: "; Increment->dump()); + LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); + LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); + return true; +} + // Finds the induction variable, increment and trip count for a simple loop that // we can flatten. static bool findLoopComponents( @@ -164,49 +175,63 @@ static bool findLoopComponents( return false; } // The trip count is the RHS of the compare. If this doesn't match the trip - // count computed by SCEV then this is either because the trip count variable - // has been widened (then leave the trip count as it is), or because it is a - // constant and another transformation has changed the compare, e.g. - // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1. - TripCount = Compare->getOperand(1); + // count computed by SCEV then this is because the trip count variable + // has been widened so the types don't match, or because it is a constant and + // another transformation has changed the compare (e.g. icmp ult %inc, + // tripcount -> icmp ult %j, tripcount-1), or both. + Value *RHS = Compare->getOperand(1); const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); if (isa(BackedgeTakenCount)) { LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); return false; } const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount); - if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) { - ConstantInt *RHS = dyn_cast(TripCount); - if (!RHS) { - LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); - return false; - } - // The L->isCanonical check above ensures we only get here if the loop - // increments by 1 on each iteration, so the RHS of the Compare is - // tripcount-1 (i.e equivalent to the backedge taken count). - assert(SE->getSCEV(RHS) == BackedgeTakenCount && - "Expected RHS of compare to be equal to the backedge taken count"); - ConstantInt *One = ConstantInt::get(RHS->getType(), 1); - TripCount = ConstantInt::get(TripCount->getContext(), - RHS->getValue() + One->getValue()); - } else if (SE->getSCEV(TripCount) != SCEVTripCount) { - auto *TripCountInst = dyn_cast(TripCount); - if (!TripCountInst) { - LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); - return false; + const SCEV *SCEVRHS = SE->getSCEV(RHS); + if (SCEVRHS == SCEVTripCount) + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); + ConstantInt *ConstantRHS = dyn_cast(RHS); + if (ConstantRHS) { + const SCEV *BackedgeTCExt = nullptr; + if (IsWidened) { + const SCEV *SCEVTripCountExt; + // Find the extended backedge taken count and extended trip count using + // SCEV. One of these should now match the RHS of the compare. + BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); + SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt); + if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } } - if ((!isa(TripCountInst) && !isa(TripCountInst)) || - SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { - LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); - return false; + // If the RHS of the compare is equal to the backedge taken count we need + // to add one to get the trip count. + if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) { + ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1); + Value *NewRHS = ConstantInt::get( + ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue()); + return setLoopComponents(NewRHS, TripCount, Increment, + IterationInstructions); } + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); } - IterationInstructions.insert(Increment); - LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); - LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); - - LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); - return true; + // If the RHS isn't a constant then check that the reason it doesn't match + // the SCEV trip count is because the RHS is a ZExt or SExt instruction + // (and take the trip count to be the RHS). + if (!IsWidened) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + auto *TripCountInst = dyn_cast(RHS); + if (!TripCountInst) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + if ((!isa(TripCountInst) && !isa(TripCountInst)) || + SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + return setLoopComponents(RHS, TripCount, Increment, IterationInstructions); } static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) { diff --git a/llvm/test/Transforms/LoopFlatten/widen-iv.ll b/llvm/test/Transforms/LoopFlatten/widen-iv.ll index a6b13e4..abd7013 100644 --- a/llvm/test/Transforms/LoopFlatten/widen-iv.ll +++ b/llvm/test/Transforms/LoopFlatten/widen-iv.ll @@ -525,6 +525,52 @@ for.cond.cleanup: ret void } +; Identify trip count when it is constant and the IV has been widened. +define i32 @constTripCount() { +; CHECK-LABEL: @constTripCount( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 20, 20 +; CHECK-NEXT: br label [[I_LOOP:%.*]] +; CHECK: i.loop: +; CHECK-NEXT: [[INDVAR1:%.*]] = phi i64 [ [[INDVAR_NEXT2:%.*]], [[J_LOOPDONE:%.*]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: br label [[J_LOOP:%.*]] +; CHECK: j.loop: +; CHECK-NEXT: [[INDVAR:%.*]] = phi i64 [ 0, [[I_LOOP]] ] +; CHECK-NEXT: call void @payload() +; CHECK-NEXT: [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1 +; CHECK-NEXT: [[J_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT]], 20 +; CHECK-NEXT: br label [[J_LOOPDONE]] +; CHECK: j.loopdone: +; CHECK-NEXT: [[INDVAR_NEXT2]] = add i64 [[INDVAR1]], 1 +; CHECK-NEXT: [[I_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT2]], [[FLATTEN_TRIPCOUNT]] +; CHECK-NEXT: br i1 [[I_ATEND]], label [[I_LOOPDONE:%.*]], label [[I_LOOP]] +; CHECK: i.loopdone: +; CHECK-NEXT: ret i32 0 +; +entry: + br label %i.loop + +i.loop: + %i = phi i8 [ 0, %entry ], [ %i.inc, %j.loopdone ] + br label %j.loop + +j.loop: + %j = phi i8 [ 0, %i.loop ], [ %j.inc, %j.loop ] + call void @payload() + %j.inc = add i8 %j, 1 + %j.atend = icmp eq i8 %j.inc, 20 + br i1 %j.atend, label %j.loopdone, label %j.loop + +j.loopdone: + %i.inc = add i8 %i, 1 + %i.atend = icmp eq i8 %i.inc, 20 + br i1 %i.atend, label %i.loopdone, label %i.loop + +i.loopdone: + ret i32 0 +} + +declare void @payload() declare dso_local i32 @use_32(i32) declare dso_local i32 @use_16(i16) declare dso_local i32 @use_64(i64) -- 2.7.4