/// (splitting the exit block as necessary). It simplifies the branch within
/// the loop to an unconditional branch but doesn't remove it entirely. Further
/// cleanup can be done with some simplify-cfg like pass.
+///
+/// If `SE` is not null, its cached information for the loops affected by this
+/// unswitch is forgotten.
static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
- LoopInfo &LI) {
+ LoopInfo &LI, ScalarEvolution *SE) {
assert(BI.isConditional() && "Can only unswitch a conditional branch!");
LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n");
}
});
+ // If we have ScalarEvolution, invalidate its cached information for this
+ // loop and for the loop containing the exit block.
+ if (SE) {
+ if (Loop *ExitL = LI.getLoopFor(LoopExitBB))
+ SE->forgetLoop(ExitL);
+ else
+ // Forget the entire nest as this exits the entire nest.
+ SE->forgetTopmostLoop(&L);
+ }
+
// Split the preheader, so that we know that there is a safe place to insert
// the conditional branch. We will change the preheader to have a conditional
// branch on LoopCond.
/// switch will not be revisited. If after unswitching there is only a single
/// in-loop successor, the switch is further simplified to an unconditional
/// branch. Still more cleanup can be done with some simplify-cfg like pass.
+///
+/// If `SE` is not null, its cached information for the loops affected by this
+/// unswitch is forgotten.
static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
- LoopInfo &LI) {
+ LoopInfo &LI, ScalarEvolution *SE) {
LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n");
Value *LoopCond = SI.getCondition();
LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n");
+ // We may need to invalidate SCEVs for the outermost loop reached by any of
+ // the exits.
+ Loop *OuterL = &L;
+
SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases;
ExitCases.reserve(ExitCaseIndices.size());
// We walk the case indices backwards so that we remove the last case first
// and don't disrupt the earlier indices.
for (unsigned Index : reverse(ExitCaseIndices)) {
auto CaseI = SI.case_begin() + Index;
+ // Compute the outer loop from this exit.
+ Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor());
+ if (!ExitL || ExitL->contains(OuterL))
+ OuterL = ExitL;
// Save the value of this case.
ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()});
// Delete the unswitched cases.
SI.removeCase(CaseI);
}
+ if (SE) {
+ if (OuterL)
+ SE->forgetLoop(OuterL);
+ else
+ SE->forgetTopmostLoop(&L);
+ }
+
// Check if after this all of the remaining cases point at the same
// successor.
BasicBlock *CommonSuccBB = nullptr;
///
/// The return value indicates whether anything was unswitched (and therefore
/// changed).
+///
+/// If `SE` is not null, it is passed to the trivial branch and switch
+/// unswitching routines so they can update it.
static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
- LoopInfo &LI) {
+ LoopInfo &LI, ScalarEvolution *SE) {
bool Changed = false;
// If loop header has only one reachable successor we should keep looking for
if (isa<Constant>(SI->getCondition()))
return Changed;
- if (!unswitchTrivialSwitch(L, *SI, DT, LI))
+ if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE))
// Couldn't unswitch this one so we're done.
return Changed;
// Found a trivial condition candidate: non-foldable conditional branch. If
// we fail to unswitch this, we can't do anything else that is trivial.
- if (!unswitchTrivialBranch(L, *BI, DT, LI))
+ if (!unswitchTrivialBranch(L, *BI, DT, LI, SE))
return Changed;
// Mark that we managed to unswitch something.
static bool unswitchNontrivialInvariants(
Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants,
DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
- function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+ function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+ ScalarEvolution *SE) {
auto *ParentBB = TI.getParent();
BranchInst *BI = dyn_cast<BranchInst>(&TI);
SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
OuterExitL = NewOuterExitL;
}
+ // At this point, we're definitely going to unswitch something so invalidate
+ // any cached information in ScalarEvolution for the outermost loop
+ // containing an exit block and all nested loops.
+ if (SE) {
+ if (OuterExitL)
+ SE->forgetLoop(OuterExitL);
+ else
+ SE->forgetTopmostLoop(&L);
+ }
+
// If the edge from this terminator to a successor dominates that successor,
// store a map from each block in its dominator subtree to it. This lets us
// tell when cloning for a particular successor if a block is dominated by
return Cost;
}
-static bool unswitchBestCondition(
- Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
- TargetTransformInfo &TTI,
- function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+static bool
+unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
+ AssumptionCache &AC, TargetTransformInfo &TTI,
+ function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+ ScalarEvolution *SE) {
// Collect all invariant conditions within this loop (as opposed to an inner
// loop which would be handled when visiting that inner loop).
SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4>
<< BestUnswitchCost << ") terminator: " << *BestUnswitchTI
<< "\n");
return unswitchNontrivialInvariants(
- L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB);
+ L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE);
}
/// Unswitch control flow predicated on loop invariant conditions.
/// require duplicating any part of the loop) out of the loop body. It then
/// looks at other loop invariant control flows and tries to unswitch those as
/// well by cloning the loop if the result is small enough.
-static bool
-unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
- TargetTransformInfo &TTI, bool NonTrivial,
- function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+///
+/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also
+/// updated based on the unswitch.
+///
+/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is
+/// true, we will attempt to do non-trivial unswitching as well as trivial
+/// unswitching.
+///
+/// The `UnswitchCB` callback provided will be run after unswitching is
+/// complete, with the first parameter set to `true` if the provided loop
+/// remains a loop, and a list of new sibling loops created.
+///
+/// If `SE` is non-null, we will update that analysis based on the unswitching
+/// done.
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+ AssumptionCache &AC, TargetTransformInfo &TTI,
+ bool NonTrivial,
+ function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+ ScalarEvolution *SE) {
assert(L.isRecursivelyLCSSAForm(DT, LI) &&
"Loops must be in LCSSA form before unswitching.");
bool Changed = false;
return false;
// Try trivial unswitch first before loop over other basic blocks in the loop.
- if (unswitchAllTrivialConditions(L, DT, LI)) {
+ if (unswitchAllTrivialConditions(L, DT, LI, SE)) {
// If we unswitched successfully we will want to clean up the loop before
// processing it further so just mark it as unswitched and return.
UnswitchCB(/*CurrentLoopValid*/ true, {});
// Try to unswitch the best invariant condition. We prefer this full unswitch to
// a partial unswitch when possible below the threshold.
- if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB))
+ if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE))
return true;
// No other opportunities to unswitch.
U.markLoopAsDeleted(L, LoopName);
};
- if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial,
- UnswitchCB))
+ if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB,
+ &AR.SE))
return PreservedAnalyses::all();
// Historically this pass has had issues with the dominator tree so verify it
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
+ auto *SE = SEWP ? &SEWP->getSE() : nullptr;
+
auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid,
ArrayRef<Loop *> NewLoops) {
// If we did a non-trivial unswitch, we have added new (cloned) loops.
LPM.markLoopAsDeleted(*L);
};
- bool Changed =
- unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB);
+ bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE);
// If anything was unswitched, also clear any cached information about this
// loop.
--- /dev/null
+; RUN: opt -passes='print<scalar-evolution>,loop(unswitch,loop-instsimplify),print<scalar-evolution>' -enable-nontrivial-unswitch -S < %s 2>%t.scev | FileCheck %s
+; RUN: FileCheck %s --check-prefix=SCEV < %t.scev
+
+target triple = "x86_64-unknown-linux-gnu"
+
+declare void @f()
+
+; Check that trivially unswitching a conditional branch in an inner loop resets
+; both the inner and outer loop trip count.
+define void @test1(i32 %n, i32 %m, i1 %cond) {
+; Check that SCEV has no trip count before unswitching.
+; SCEV-LABEL: Determining loop execution counts for: @test1
+; SCEV: Loop %inner_loop_begin: <multiple exits> Unpredictable backedge-taken count.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; Now check that after unswitching and simplifying instructions we get clean
+; backedge-taken counts.
+; SCEV-LABEL: Determining loop execution counts for: @test1
+; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m))<nsw>
+; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n))<nsw>
+;
+; And verify the code matches what we expect.
+; CHECK-LABEL: define void @test1(
+entry:
+ br label %outer_loop_begin
+; Ensure the outer loop didn't get unswitched.
+; CHECK: entry:
+; CHECK-NEXT: br label %outer_loop_begin
+
+outer_loop_begin:
+ %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ]
+ ; Block unswitching of the outer loop with a noduplicate call.
+ call void @f() noduplicate
+ br label %inner_loop_begin
+; Ensure the inner loop got unswitched into the outer loop.
+; CHECK: outer_loop_begin:
+; CHECK-NEXT: %{{.*}} = phi i32
+; CHECK-NEXT: call void @f()
+; CHECK-NEXT: br i1 %cond,
+
+inner_loop_begin:
+ %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ]
+ br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit
+
+inner_loop_latch:
+ %j.next = add nsw i32 %j, 1
+ %j.cmp = icmp slt i32 %j.next, %m
+ br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit
+
+inner_loop_early_exit:
+ %j.lcssa = phi i32 [ %i, %inner_loop_begin ]
+ br label %outer_loop_latch
+
+inner_loop_late_exit:
+ br label %outer_loop_latch
+
+outer_loop_latch:
+ %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ]
+ %i.next = add nsw i32 %i.phi, 1
+ %i.cmp = icmp slt i32 %i.next, %n
+ br i1 %i.cmp, label %outer_loop_begin, label %exit
+
+exit:
+ ret void
+}
+
+; Check that trivially unswitching a switch in an inner loop resets both the
+; inner and outer loop trip count.
+define void @test2(i32 %n, i32 %m, i32 %cond) {
+; Check that SCEV has no trip count before unswitching.
+; SCEV-LABEL: Determining loop execution counts for: @test2
+; SCEV: Loop %inner_loop_begin: <multiple exits> Unpredictable backedge-taken count.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; Now check that after unswitching and simplifying instructions we get a clean
+; backedge-taken count for the inner loop.
+; SCEV-LABEL: Determining loop execution counts for: @test2
+; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m))<nsw>
+; FIXME: The following backedge-taken count should be known but apparently
+; isn't, just because of a switch in the outer loop.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; CHECK-LABEL: define void @test2(
+entry:
+ br label %outer_loop_begin
+; Ensure the outer loop didn't get unswitched.
+; CHECK: entry:
+; CHECK-NEXT: br label %outer_loop_begin
+
+outer_loop_begin:
+ %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ]
+ ; Block unswitching of the outer loop with a noduplicate call.
+ call void @f() noduplicate
+ br label %inner_loop_begin
+; Ensure the inner loop got unswitched into the outer loop.
+; CHECK: outer_loop_begin:
+; CHECK-NEXT: %{{.*}} = phi i32
+; CHECK-NEXT: call void @f()
+; CHECK-NEXT: switch i32 %cond,
+
+inner_loop_begin:
+ %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ]
+ switch i32 %cond, label %inner_loop_early_exit [
+ i32 1, label %inner_loop_latch
+ i32 2, label %inner_loop_latch
+ ]
+
+inner_loop_latch:
+ %j.next = add nsw i32 %j, 1
+ %j.cmp = icmp slt i32 %j.next, %m
+ br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit
+
+inner_loop_early_exit:
+ %j.lcssa = phi i32 [ %i, %inner_loop_begin ]
+ br label %outer_loop_latch
+
+inner_loop_late_exit:
+ br label %outer_loop_latch
+
+outer_loop_latch:
+ %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ]
+ %i.next = add nsw i32 %i.phi, 1
+ %i.cmp = icmp slt i32 %i.next, %n
+ br i1 %i.cmp, label %outer_loop_begin, label %exit
+
+exit:
+ ret void
+}
+
+; Check that non-trivial unswitching of a branch in an inner loop into the outer
+; loop invalidates both inner and outer.
+define void @test3(i32 %n, i32 %m, i1 %cond) {
+; Check that SCEV has no trip count before unswitching.
+; SCEV-LABEL: Determining loop execution counts for: @test3
+; SCEV: Loop %inner_loop_begin: <multiple exits> Unpredictable backedge-taken count.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; Now check that after unswitching and simplifying instructions we get clean
+; backedge-taken counts.
+; SCEV-LABEL: Determining loop execution counts for: @test3
+; SCEV: Loop %inner_loop_begin{{.*}}: backedge-taken count is (-1 + (1 smax %m))<nsw>
+; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n))<nsw>
+;
+; And verify the code matches what we expect.
+; CHECK-LABEL: define void @test3(
+entry:
+ br label %outer_loop_begin
+; Ensure the outer loop didn't get unswitched.
+; CHECK: entry:
+; CHECK-NEXT: br label %outer_loop_begin
+
+outer_loop_begin:
+ %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ]
+ ; Block unswitching of the outer loop with a noduplicate call.
+ call void @f() noduplicate
+ br label %inner_loop_begin
+; Ensure the inner loop got unswitched into the outer loop.
+; CHECK: outer_loop_begin:
+; CHECK-NEXT: %{{.*}} = phi i32
+; CHECK-NEXT: call void @f()
+; CHECK-NEXT: br i1 %cond,
+
+inner_loop_begin:
+ %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ]
+ %j.tmp = add nsw i32 %j, 1
+ br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit
+
+inner_loop_latch:
+ %j.next = add nsw i32 %j, 1
+ %j.cmp = icmp slt i32 %j.next, %m
+ br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit
+
+inner_loop_early_exit:
+ %j.lcssa = phi i32 [ %j.tmp, %inner_loop_begin ]
+ br label %outer_loop_latch
+
+inner_loop_late_exit:
+ br label %outer_loop_latch
+
+outer_loop_latch:
+ %inc.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ 1, %inner_loop_late_exit ]
+ %i.next = add nsw i32 %i, %inc.phi
+ %i.cmp = icmp slt i32 %i.next, %n
+ br i1 %i.cmp, label %outer_loop_begin, label %exit
+
+exit:
+ ret void
+}