[SimpleLoopUnswitch] Fix SCEV invalidation issue
authorBjorn Pettersson <bjorn.a.pettersson@ericsson.com>
Tue, 28 Mar 2023 16:07:52 +0000 (18:07 +0200)
committerBjorn Pettersson <bjorn.a.pettersson@ericsson.com>
Thu, 6 Apr 2023 07:46:42 +0000 (09:46 +0200)
This patch is making sure that we use getTopMostExitingLoop when
finding out which loops to forget, when dealing with
unswitchNontrivialInvariants and unswitchTrivialSwitch. It seems
to at least be needed for unswitchNontrivialInvariants as detected
by the included test case.

Note that unswitchTrivialBranch already used getTopMostExitingLoop.
This was done in commit 4a9cde5a791cd49b96993e6. The commit
message in that commit says "If the patch makes sense, I will also
update those places to a similar approach ...", referring to these
functions mentioned above. As far as I can tell that never happened,
but this is an attempt to finally fix that.

Fixes https://github.com/llvm/llvm-project/issues/61080

Differential Revision: https://reviews.llvm.org/D147058

llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
llvm/test/Transforms/SimpleLoopUnswitch/update-scev-3.ll [new file with mode: 0644]

index 3cc4729..1e2c9c9 100644 (file)
@@ -475,10 +475,10 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader,
 // Return the top-most loop containing ExitBB and having ExitBB as exiting block
 // or the loop containing ExitBB, if there is no parent loop containing ExitBB
 // as exiting block.
-static const Loop *getTopMostExitingLoop(const BasicBlock *ExitBB,
-                                         const LoopInfo &LI) {
-  const Loop *TopMost = LI.getLoopFor(ExitBB);
-  const Loop *Current = TopMost;
+static Loop *getTopMostExitingLoop(const BasicBlock *ExitBB,
+                                   const LoopInfo &LI) {
+  Loop *TopMost = LI.getLoopFor(ExitBB);
+  Loop *Current = TopMost;
   while (Current) {
     if (Current->isLoopExiting(ExitBB))
       TopMost = Current;
@@ -792,14 +792,14 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
 
   if (DefaultExitBB) {
     // Check the loop containing this exit.
-    Loop *ExitL = LI.getLoopFor(DefaultExitBB);
+    Loop *ExitL = getTopMostExitingLoop(DefaultExitBB, LI);
     if (!ExitL || ExitL->contains(OuterL))
       OuterL = ExitL;
   }
   for (unsigned Index : ExitCaseIndices) {
     auto CaseI = SI.case_begin() + Index;
     // Compute the outer loop from this exit.
-    Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor());
+    Loop *ExitL = getTopMostExitingLoop(CaseI->getCaseSuccessor(), LI);
     if (!ExitL || ExitL->contains(OuterL))
       OuterL = ExitL;
   }
@@ -2207,7 +2207,9 @@ static void unswitchNontrivialInvariants(
   SmallVector<BasicBlock *, 4> ExitBlocks;
   L.getUniqueExitBlocks(ExitBlocks);
   for (auto *ExitBB : ExitBlocks) {
-    Loop *NewOuterExitL = LI.getLoopFor(ExitBB);
+    // ExitBB can be an exit block for several levels in the loop nest. Make
+    // sure we find the top most.
+    Loop *NewOuterExitL = getTopMostExitingLoop(ExitBB, LI);
     if (!NewOuterExitL) {
       // We exited the entire nest with this block, so we're done.
       OuterExitL = nullptr;
diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/update-scev-3.ll b/llvm/test/Transforms/SimpleLoopUnswitch/update-scev-3.ll
new file mode 100644 (file)
index 0000000..ef00d7e
--- /dev/null
@@ -0,0 +1,186 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes='print<scalar-evolution>,simple-loop-unswitch<nontrivial>' -verify-scev -S | FileCheck %s
+
+; This is a reproducer for https://github.com/llvm/llvm-project/issues/61080
+;
+; Note that the print<scalar-evolution> in the beginning of the pipeline is
+; needed for reproducing the original problem, as we need something that
+; calculate SCEV before the loop unswitch, to verify that we invalidate SCEV
+; correctly while doing the unswitch.
+;
+; Verify that we no longer hit that assert. Also perform checks to show the IR
+; transformation that is going on in this test.
+
+
+define i32 @foo(i1 %not) {
+; CHECK-LABEL: define i32 @foo
+; CHECK-SAME: (i1 [[NOT:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[FALSE:%.*]] = and i1 true, false
+; CHECK-NEXT:    br i1 [[NOT]], label [[ENTRY_SPLIT_US:%.*]], label [[ENTRY_SPLIT:%.*]]
+; CHECK:       entry.split.us:
+; CHECK-NEXT:    br i1 [[FALSE]], label [[ENTRY_SPLIT_US_SPLIT_US:%.*]], label [[ENTRY_SPLIT_US_SPLIT:%.*]]
+; CHECK:       entry.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND_US_US:%.*]]
+; CHECK:       for.cond.us.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US_US_US:%.*]]
+; CHECK:       for.cond.split.us.us.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US_SPLIT_US_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       for.cond.split.us.split.us.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       entry.split.us.split:
+; CHECK-NEXT:    br label [[FOR_COND_US:%.*]]
+; CHECK:       for.cond.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US_US:%.*]]
+; CHECK:       for.inc11.us:
+; CHECK-NEXT:    br label [[FOR_COND_US]]
+; CHECK:       for.cond.split.us.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US_SPLIT_US11:%.*]]
+; CHECK:       for.cond5.preheader.us.us9:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US_US_US10:%.*]]
+; CHECK:       for.inc8.us.us:
+; CHECK-NEXT:    br i1 false, label [[FOR_INC8_FOR_COND5_PREHEADER_CRIT_EDGE_US_US:%.*]], label [[FOR_INC11_SPLIT_US_US:%.*]]
+; CHECK:       for.inc8.for.cond5.preheader_crit_edge.us.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_US_US9:%.*]]
+; CHECK:       for.end.us.us:
+; CHECK-NEXT:    br i1 false, label [[FOR_INC8_US_US:%.*]], label [[CLEANUP15_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       for.cond5.preheader.split.us.us.us10:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US_SPLIT_US7_US:%.*]]
+; CHECK:       for.body7.us.us4.us:
+; CHECK-NEXT:    br label [[HANDLER_POINTER_OVERFLOW_US_US5_US:%.*]]
+; CHECK:       handler.pointer_overflow.us.us5.us:
+; CHECK-NEXT:    br label [[CONT_US_US6_US:%.*]]
+; CHECK:       cont.us.us6.us:
+; CHECK-NEXT:    br label [[FOR_END_SPLIT_US_US_US:%.*]]
+; CHECK:       for.end.split.us.us.us:
+; CHECK-NEXT:    br label [[FOR_END_US_US:%.*]]
+; CHECK:       for.cond5.preheader.split.us.split.us7.us:
+; CHECK-NEXT:    br label [[FOR_BODY7_US_US4_US:%.*]]
+; CHECK:       for.inc11.split.us.us:
+; CHECK-NEXT:    br label [[FOR_INC11_US:%.*]]
+; CHECK:       for.cond.split.us.split.us11:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_US_US9]]
+; CHECK:       for.cond.split.us.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       cleanup15.split.us.split.us:
+; CHECK-NEXT:    br label [[CLEANUP15_SPLIT_US:%.*]]
+; CHECK:       entry.split:
+; CHECK-NEXT:    br i1 [[FALSE]], label [[ENTRY_SPLIT_SPLIT_US:%.*]], label [[ENTRY_SPLIT_SPLIT:%.*]]
+; CHECK:       entry.split.split.us:
+; CHECK-NEXT:    br label [[FOR_COND_US12:%.*]]
+; CHECK:       for.cond.us12:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_US:%.*]]
+; CHECK:       for.cond.split.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       for.cond.split.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_SPLIT_US:%.*]]
+; CHECK:       entry.split.split:
+; CHECK-NEXT:    br label [[FOR_COND:%.*]]
+; CHECK:       for.cond:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT:%.*]]
+; CHECK:       for.cond.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_US_US:%.*]]
+; CHECK:       for.cond5.preheader.us.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US_US_US:%.*]]
+; CHECK:       for.cond5.preheader.split.us.us.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US_SPLIT_US_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       for.cond5.preheader.split.us.split.us.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       cleanup15.split.us:
+; CHECK-NEXT:    br label [[CLEANUP15:%.*]]
+; CHECK:       for.cond5.preheader.split.us.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       for.cond.split:
+; CHECK-NEXT:    br label [[FOR_COND_SPLIT_SPLIT:%.*]]
+; CHECK:       for.cond.split.split.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_US8:%.*]]
+; CHECK:       for.cond5.preheader.us8:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_US:%.*]]
+; CHECK:       for.cond5.preheader.split.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_SPLIT_US_SPLIT_US:%.*]]
+; CHECK:       for.cond5.preheader.split.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_SPLIT_US:%.*]]
+; CHECK:       for.cond.split.split:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER:%.*]]
+; CHECK:       for.cond5.preheader:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT:%.*]]
+; CHECK:       for.cond5.preheader.split.us.split.us:
+; CHECK-NEXT:    br label [[FOR_BODY7_US_US:%.*]]
+; CHECK:       for.body7.us.us:
+; CHECK-NEXT:    br label [[HANDLER_POINTER_OVERFLOW_US_US:%.*]]
+; CHECK:       handler.pointer_overflow.us.us:
+; CHECK-NEXT:    br label [[CONT_US_US:%.*]]
+; CHECK:       cont.us.us:
+; CHECK-NEXT:    br label [[CONT_FOR_BODY7_CRIT_EDGE_US_US:%.*]]
+; CHECK:       cont.for.body7_crit_edge.us.us:
+; CHECK-NEXT:    br label [[FOR_BODY7_US_US]]
+; CHECK:       for.cond5.preheader.split:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER_SPLIT_SPLIT:%.*]]
+; CHECK:       for.cond5.preheader.split.split.us:
+; CHECK-NEXT:    br label [[FOR_BODY7_US1:%.*]]
+; CHECK:       for.body7.us1:
+; CHECK-NEXT:    br label [[CONT_US2:%.*]]
+; CHECK:       cont.us2:
+; CHECK-NEXT:    br label [[CONT_FOR_BODY7_CRIT_EDGE_US3:%.*]]
+; CHECK:       cont.for.body7_crit_edge.us3:
+; CHECK-NEXT:    br label [[FOR_BODY7_US1]]
+; CHECK:       for.cond5.preheader.split.split:
+; CHECK-NEXT:    br label [[FOR_BODY7:%.*]]
+; CHECK:       for.body7:
+; CHECK-NEXT:    br label [[CONT:%.*]]
+; CHECK:       cont:
+; CHECK-NEXT:    br label [[FOR_END_SPLIT:%.*]]
+; CHECK:       for.end.split:
+; CHECK-NEXT:    br label [[FOR_END:%.*]]
+; CHECK:       for.end:
+; CHECK-NEXT:    br i1 false, label [[FOR_INC8:%.*]], label [[CLEANUP15_SPLIT:%.*]]
+; CHECK:       for.inc8:
+; CHECK-NEXT:    br i1 false, label [[FOR_INC8_FOR_COND5_PREHEADER_CRIT_EDGE:%.*]], label [[FOR_INC11_SPLIT:%.*]]
+; CHECK:       for.inc8.for.cond5.preheader_crit_edge:
+; CHECK-NEXT:    br label [[FOR_COND5_PREHEADER]]
+; CHECK:       for.inc11.split:
+; CHECK-NEXT:    br label [[FOR_INC11:%.*]]
+; CHECK:       for.inc11:
+; CHECK-NEXT:    br label [[FOR_COND]]
+; CHECK:       cleanup15.split:
+; CHECK-NEXT:    br label [[CLEANUP15]]
+; CHECK:       cleanup15:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %false = and i1 1, 0
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.inc11, %entry
+  br label %for.cond5.preheader
+
+for.cond5.preheader:                              ; preds = %for.inc8.for.cond5.preheader_crit_edge, %for.cond
+  br label %for.body7
+
+for.body7:                                        ; preds = %cont.for.body7_crit_edge, %for.cond5.preheader
+  br i1 %not, label %handler.pointer_overflow, label %cont
+
+handler.pointer_overflow:                         ; preds = %for.body7
+  br label %cont
+
+cont:                                             ; preds = %handler.pointer_overflow, %for.body7
+  br i1 %false, label %cont.for.body7_crit_edge, label %for.end
+
+cont.for.body7_crit_edge:                         ; preds = %cont
+  br label %for.body7
+
+for.end:                                          ; preds = %cont
+  br i1 %false, label %for.inc8, label %cleanup15
+
+for.inc8:                                         ; preds = %for.end
+  br i1 %false, label %for.inc8.for.cond5.preheader_crit_edge, label %for.inc11
+
+for.inc8.for.cond5.preheader_crit_edge:           ; preds = %for.inc8
+  br label %for.cond5.preheader
+
+for.inc11:                                        ; preds = %for.inc8
+  br label %for.cond
+
+cleanup15:                                        ; preds = %for.end
+  ret i32 0
+}