Re-enable "[SCEV] Prove implications of different type via truncation"
authorMax Kazantsev <mkazantsev@azul.com>
Wed, 28 Oct 2020 06:28:39 +0000 (13:28 +0700)
committerMax Kazantsev <mkazantsev@azul.com>
Wed, 28 Oct 2020 09:02:14 +0000 (16:02 +0700)
When we need to prove implication of expressions of different type width,
the default strategy is to widen everything to wider type and prove in this
type. This does not interact well with AddRecs with negative steps and
unsigned predicates: such AddRec will likely not have a `nuw` flag, and its
`zext` to wider type will not be an AddRec. In contraty, `trunc` of an AddRec
in some cases can easily be proved to be an `AddRec` too.

This patch introduces an alternative way to handling implications of different
type widths. If we can prove that wider type values actually fit in the narrow type,
we truncate them and prove the implication in narrow type.

The return was due to revert of underlying patch that this one depends on.

Unit test temporarily disabled because the required logic in SCEV is switched
off due to compile time reasons.

Differential Revision: https://reviews.llvm.org/D89548

llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/srem.ll
llvm/unittests/Analysis/ScalarEvolutionTest.cpp

index 17128fb9ad9d20383649c0784b4d2d2bb755fd34..fe6bfd520117f41d0d827216b1ac821096e5a834 100644 (file)
@@ -9794,6 +9794,25 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   // Balance the types.
   if (getTypeSizeInBits(LHS->getType()) <
       getTypeSizeInBits(FoundLHS->getType())) {
+    // For unsigned and equality predicates, try to prove that both found
+    // operands fit into narrow unsigned range. If so, try to prove facts in
+    // narrow types.
+    if (!CmpInst::isSigned(FoundPred)) {
+      auto *NarrowType = LHS->getType();
+      auto *WideType = FoundLHS->getType();
+      auto BitWidth = getTypeSizeInBits(NarrowType);
+      const SCEV *MaxValue = getZeroExtendExpr(
+          getConstant(APInt::getMaxValue(BitWidth)), WideType);
+      if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) &&
+          isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) {
+        const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
+        const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
+        if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
+                                       TruncFoundRHS, Context))
+          return true;
+      }
+    }
+
     if (CmpInst::isSigned(Pred)) {
       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
       RHS = getSignExtendExpr(RHS, FoundLHS->getType());
index 197437b51ca120188a942782c44fd63482190053..76e0d4a5ec5b5f3a88613c8a313bafbd2b8e955f 100644 (file)
@@ -29,7 +29,7 @@ define dso_local void @_Z4loopi(i32 %width) local_unnamed_addr #0 {
 ; CHECK-NEXT:    %add = add nsw i32 %2, %call
 ; CHECK-NEXT:    --> (%2 + %call) U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
 ; CHECK-NEXT:    %inc = add nsw i32 %i.0, 1
-; CHECK-NEXT:    --> {1,+,1}<nuw><%for.cond> U: [1,0) S: [1,0) Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
+; CHECK-NEXT:    --> {1,+,1}<nuw><%for.cond> U: full-set S: full-set Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @_Z4loopi
 ; CHECK-NEXT:  Loop %for.cond: backedge-taken count is %width
 ; CHECK-NEXT:  Loop %for.cond: max backedge-taken count is -1
index be8941838f71afb3b3886b9dd3984f90e8360097..909a140296ac62b85eebcd6d0065c3c132cd706c 100644 (file)
@@ -1316,4 +1316,51 @@ TEST_F(ScalarEvolutionsTest, UnsignedIsImpliedViaOperations) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ProveImplicationViaNarrowing) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define i32 @foo(i32 %start, i32* %q) { "
+      "entry: "
+      "  %wide.start = zext i32 %start to i64 "
+      "  br label %loop "
+      "loop: "
+      "  %wide.iv = phi i64 [%wide.start, %entry], [%wide.iv.next, %backedge] "
+      "  %iv = phi i32 [%start, %entry], [%iv.next, %backedge] "
+      "  %cond = icmp eq i64 %wide.iv, 0 "
+      "  br i1 %cond, label %exit, label %backedge "
+      "backedge: "
+      "  %iv.next = add i32 %iv, -1 "
+      "  %index = zext i32 %iv.next to i64 "
+      "  %load.addr = getelementptr i32, i32* %q, i64 %index "
+      "  %stop = load i32, i32* %load.addr "
+      "  %loop.cond = icmp eq i32 %stop, 0 "
+      "  %wide.iv.next = add nsw i64 %wide.iv, -1 "
+      "  br i1 %loop.cond, label %loop, label %failure "
+      "exit: "
+      "  ret i32 0 "
+      "failure: "
+      "  unreachable "
+      "} ",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *IV = SE.getSCEV(getInstructionByName(F, "iv"));
+    auto *Zero = SE.getZero(IV->getType());
+    auto *Backedge = getInstructionByName(F, "iv.next")->getParent();
+    ASSERT_TRUE(Backedge);
+    (void)IV;
+    (void)Zero;
+    // FIXME: This can only be proved with turned on option
+    // scalar-evolution-use-expensive-range-sharpening which is currently off.
+    // Enable the check once it's switched true by default.
+    // EXPECT_TRUE(SE.isBasicBlockEntryGuardedByCond(Backedge,
+    //                                               ICmpInst::ICMP_UGT,
+    //                                               IV, Zero));
+  });
+}
+
 }  // end namespace llvm