[SCEV] Help getLoopInvariantExitCondDuringFirstIterations deal with complex `umin...
authorMax Kazantsev <mkazantsev@azul.com>
Wed, 21 Dec 2022 05:47:36 +0000 (12:47 +0700)
committerMax Kazantsev <mkazantsev@azul.com>
Wed, 21 Dec 2022 11:12:17 +0000 (18:12 +0700)
Recent improvements in symbolic exit count computation revealed some problems with
SCEV's ability to find invariant predicate during first iterations. Ultimately it is based on its
ability to prove some facts for value on the last iteration. This last value, when it includes
`umin` as part of exit count, isn't always simplified enough. The motivating example is following:

https://github.com/llvm/llvm-project/issues/59615

Could not prove:
```
        Pred = 36, LHS = (-1 + (-1 * (2147483645 umin (-1 + %var)<nsw>))<nsw> + %var), RHS = %var
        FoundPred = 36, FoundLHS = {1,+,1}<nuw><nsw><%bb3>, FoundRHS = %var
```
Can prove:
```
        Pred = 36, LHS = (-1 + (-1 * (-1 + %var)<nsw>)<nsw> + %var), RHS = %var
        FoundPred = 36, FoundLHS = {1,+,1}<nuw><nsw><%bb3>, FoundRHS = %var
```

Here ` (2147483645 umin (-1 + %var)<nsw>)` is exit count composed of two parts from
two different exits: `2147483645 ` and `(-1 + %var)<nsw>`. When it was only one (latter)
analyzeable exit, for it everything was easily provable. Unfortunately, in general case `umin`
in one of `add`'s operands doesn't guarantee that the whole sum reduces, especially in presence
of negative steps and lack of `nuw`. I don't think there is a generic legal way to somehow play
around this `umin`.

So the ad-hoc solution is following: if we failed to find an equivalent predicate that is invariant
during first `MaxIter` iterations, and `MaxIter = umin(a, b, c...)`, try to find solution for at least one
of `a`, `b`, `c`... Because they all are `uge` than `MaxIter`, whatever is true during `a (b, c)` iterations
is also true during `MaxIter` iterations.

Differential Revision: https://reviews.llvm.org/D140456
Reviewed By: nikic

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Transforms/IndVarSimplify/X86/pr59615.ll

index 917b6d1..eb19488 100644 (file)
@@ -1122,6 +1122,11 @@ public:
                                                 const Instruction *CtxI,
                                                 const SCEV *MaxIter);
 
+  std::optional<LoopInvariantPredicate>
+  getLoopInvariantExitCondDuringFirstIterationsImpl(
+      ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+      const Instruction *CtxI, const SCEV *MaxIter);
+
   /// Simplify LHS and RHS in a comparison with predicate Pred. Return true
   /// iff any changes were made. If the operands are provably equal or
   /// unequal, LHS and RHS are set to the same value and Pred is set to either
index 3c445aa..34a9cdd 100644 (file)
@@ -11046,6 +11046,26 @@ std::optional<ScalarEvolution::LoopInvariantPredicate>
 ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
     const Instruction *CtxI, const SCEV *MaxIter) {
+  if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
+          Pred, LHS, RHS, L, CtxI, MaxIter))
+    return LIP;
+  if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
+    // Number of iterations expressed as UMIN isn't always great for expressing
+    // the value on the last iteration. If the straightforward approach didn't
+    // work, try the following trick: if the a predicate is invariant for X, it
+    // is also invariant for umin(X, ...). So try to find something that works
+    // among subexpressions of MaxIter expressed as umin.
+    for (auto *Op : UMin->operands())
+      if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
+              Pred, LHS, RHS, L, CtxI, Op))
+        return LIP;
+  return std::nullopt;
+}
+
+std::optional<ScalarEvolution::LoopInvariantPredicate>
+ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+    const Instruction *CtxI, const SCEV *MaxIter) {
   // Try to prove the following set of facts:
   // - The predicate is monotonic in the iteration space.
   // - If the check does not fail on the 1st iteration:
index 71f9d90..17b7b9d 100644 (file)
@@ -8,35 +8,30 @@ define void @test() {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[VAR:%.*]] = load atomic i32, ptr addrspace(1) poison unordered, align 8, !range [[RNG0:![0-9]+]], !invariant.load !1, !noundef !1
-; CHECK-NEXT:    [[VAR1:%.*]] = add nsw i32 [[VAR]], -1
 ; CHECK-NEXT:    [[VAR2:%.*]] = icmp eq i32 [[VAR]], 0
 ; CHECK-NEXT:    br i1 [[VAR2]], label [[BB18:%.*]], label [[BB19:%.*]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ 0, [[BB19]] ], [ [[INDVARS_IV_NEXT:%.*]], [[BB12:%.*]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i64 [[TMP3:%.*]], [[INDVARS_IV]]
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
-; CHECK-NEXT:    [[VAR6:%.*]] = icmp ult i32 [[TMP1]], [[VAR]]
-; CHECK-NEXT:    br i1 [[VAR6]], label [[BB8:%.*]], label [[BB7:%.*]]
+; CHECK-NEXT:    br i1 true, label [[BB8:%.*]], label [[BB7:%.*]]
 ; CHECK:       bb7:
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb8:
 ; CHECK-NEXT:    [[VAR9:%.*]] = load atomic i32, ptr addrspace(1) poison unordered, align 8, !range [[RNG0]], !invariant.load !1, !noundef !1
-; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[VAR9]] to i64
-; CHECK-NEXT:    [[VAR10:%.*]] = icmp ult i64 [[INDVARS_IV]], [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[VAR9]] to i64
+; CHECK-NEXT:    [[VAR10:%.*]] = icmp ult i64 [[INDVARS_IV]], [[TMP0]]
 ; CHECK-NEXT:    br i1 [[VAR10]], label [[BB12]], label [[BB11:%.*]]
 ; CHECK:       bb11:
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb12:
 ; CHECK-NEXT:    [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 1
-; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT]], [[WIDE_TRIP_COUNT:%.*]]
+; CHECK-NEXT:    [[EXITCOND:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT]], [[TMP1:%.*]]
 ; CHECK-NEXT:    br i1 [[EXITCOND]], label [[BB3:%.*]], label [[BB17:%.*]]
 ; CHECK:       bb17:
 ; CHECK-NEXT:    unreachable
 ; CHECK:       bb18:
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb19:
-; CHECK-NEXT:    [[TMP3]] = sext i32 [[VAR1]] to i64
-; CHECK-NEXT:    [[WIDE_TRIP_COUNT]] = zext i32 [[VAR]] to i64
+; CHECK-NEXT:    [[TMP1]] = zext i32 [[VAR]] to i64
 ; CHECK-NEXT:    br label [[BB3]]
 ;
 bb: