Revert "[MergeICmps] Adapt to non-eq comparisons"
authorVitaly Buka <vitalybuka@google.com>
Fri, 13 Jan 2023 02:09:15 +0000 (18:09 -0800)
committerVitaly Buka <vitalybuka@google.com>
Fri, 13 Jan 2023 02:14:47 +0000 (18:14 -0800)
Breaks ubsan build, details in D141188.

This reverts commit 3ac2b3a4f9effc9f79822e770f209fd70ff66362.

llvm/lib/Transforms/Scalar/MergeICmps.cpp
llvm/test/Transforms/MergeICmps/X86/pr59740.ll [deleted file]

index bcb95f8..bcedb05 100644 (file)
@@ -330,10 +330,10 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
 
 // Visit the given comparison block. If this is a comparison between two valid
 // BCE atoms, returns the comparison.
-std::optional<BCECmpBlock>
-visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate,
-              Value *const Val, BasicBlock *const Block,
-              const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
+std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
+                                         BasicBlock *const Block,
+                                         const BasicBlock *const PhiBlock,
+                                         BaseIdentifier &BaseId) {
   if (Block->empty())
     return std::nullopt;
   auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
@@ -348,27 +348,15 @@ visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate,
     // that this does not mean that this is the last incoming value, blocks
     // can be reordered).
     Cond = Val;
-    const auto *const ConstBase = cast<ConstantInt>(Baseline);
-    assert(ConstBase->getType()->isIntegerTy(1) &&
-           "Select condition is not an i1?");
-    ExpectedPredicate =
-        ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
-
-    // Remember the correct predicate.
-    Predicate = ExpectedPredicate;
+    ExpectedPredicate = ICmpInst::ICMP_EQ;
   } else {
-    // All the incoming values must be consistent.
-    if (Baseline != Val)
-      return std::nullopt;
     // In this case, we expect a constant incoming value (the comparison is
     // chained).
     const auto *const Const = cast<ConstantInt>(Val);
-    assert(Const->getType()->isIntegerTy(1) &&
-           "Incoming value is not an i1?");
     LLVM_DEBUG(dbgs() << "const\n");
-    if (!Const->isZero() && !Const->isOne())
+    if (!Const->isZero())
       return std::nullopt;
-    LLVM_DEBUG(dbgs() << *Const << "\n");
+    LLVM_DEBUG(dbgs() << "false\n");
     assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
     BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
     Cond = BranchI->getCondition();
@@ -429,8 +417,6 @@ private:
   std::vector<ContiguousBlocks> MergedBlocks_;
   // The original entry block (before sorting);
   BasicBlock *EntryBlock_;
-  // Remember the predicate type of the chain.
-  ICmpInst::Predicate Predicate_;
 };
 
 static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
@@ -489,13 +475,10 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
   // Now look inside blocks to check for BCE comparisons.
   std::vector<BCECmpBlock> Comparisons;
   BaseIdentifier BaseId;
-  Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]);
-  Predicate_ = CmpInst::BAD_ICMP_PREDICATE;
   for (BasicBlock *const Block : Blocks) {
     assert(Block && "invalid block");
-    std::optional<BCECmpBlock> Comparison =
-        visitCmpBlock(Baseline, Predicate_, Phi.getIncomingValueForBlock(Block),
-                      Block, Phi.getParent(), BaseId);
+    std::optional<BCECmpBlock> Comparison = visitCmpBlock(
+        Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId);
     if (!Comparison) {
       LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
       return;
@@ -619,8 +602,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
                                     BasicBlock *const InsertBefore,
                                     BasicBlock *const NextCmpBlock,
                                     PHINode &Phi, const TargetLibraryInfo &TLI,
-                                    AliasAnalysis &AA, DomTreeUpdater &DTU,
-                                    ICmpInst::Predicate Predicate) {
+                                    AliasAnalysis &AA, DomTreeUpdater &DTU) {
   assert(!Comparisons.empty() && "merging zero comparisons");
   LLVMContext &Context = NextCmpBlock->getContext();
   const BCECmpBlock &FirstCmp = Comparisons[0];
@@ -641,7 +623,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
   else
     Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
 
-  Value *ICmpValue = nullptr;
+  Value *IsEqual = nullptr;
   LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
                     << BB->getName() << "\n");
 
@@ -662,7 +644,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
     Value *const RhsLoad =
         Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
     // There are no blocks to merge, just do the comparison.
-    ICmpValue = Builder.CreateICmp(Predicate, LhsLoad, RhsLoad);
+    IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
   } else {
     const unsigned TotalSizeBits = std::accumulate(
         Comparisons.begin(), Comparisons.end(), 0u,
@@ -678,8 +660,8 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
         Lhs, Rhs,
         ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8),
         Builder, DL, &TLI);
-    ICmpValue = Builder.CreateICmp(
-        Predicate, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
+    IsEqual = Builder.CreateICmpEQ(
+        MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
   }
 
   BasicBlock *const PhiBB = Phi.getParent();
@@ -687,11 +669,11 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
   if (NextCmpBlock == PhiBB) {
     // Continue to phi, passing it the comparison result.
     Builder.CreateBr(PhiBB);
-    Phi.addIncoming(ICmpValue, BB);
+    Phi.addIncoming(IsEqual, BB);
     DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
   } else {
     // Continue to next block if equal, exit to phi else.
-    Builder.CreateCondBr(ICmpValue, NextCmpBlock, PhiBB);
+    Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
     Phi.addIncoming(ConstantInt::getFalse(Context), BB);
     DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
                       {DominatorTree::Insert, BB, PhiBB}});
@@ -709,11 +691,9 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
   // so that the next block is always available to branch to.
   BasicBlock *InsertBefore = EntryBlock_;
   BasicBlock *NextCmpBlock = Phi_.getParent();
-  assert(Predicate_ != CmpInst::BAD_ICMP_PREDICATE &&
-         "Got the chain of comparisons");
   for (const auto &Blocks : reverse(MergedBlocks_)) {
     InsertBefore = NextCmpBlock = mergeComparisons(
-        Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate_);
+        Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
   }
 
   // Replace the original cmp chain with the new cmp chain by pointing all
diff --git a/llvm/test/Transforms/MergeICmps/X86/pr59740.ll b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll
deleted file mode 100644 (file)
index 6b46325..0000000
+++ /dev/null
@@ -1,86 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s
-
-%struct.S = type { i8, i8, i8, i8 }
-
-define noundef i1 @_Z2neR1SS0_(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
-; CHECK-LABEL: @_Z2neR1SS0_(
-; CHECK-NEXT:  "bb0+bb1+bb2+bb3":
-; CHECK-NEXT:    [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4)
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0
-; CHECK-NEXT:    br label [[BB4:%.*]]
-; CHECK:       bb4:
-; CHECK-NEXT:    ret i1 [[TMP0]]
-;
-bb0:
-  %v0 = load i8, ptr %s0, align 1
-  %v1 = load i8, ptr %s1, align 1
-  %cmp0 = icmp eq i8 %v0, %v1
-  br i1 %cmp0, label %bb1, label %bb4
-
-bb1:                                              ; preds = %bb0
-  %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
-  %v2 = load i8, ptr %s2, align 1
-  %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
-  %v3 = load i8, ptr %s3, align 1
-  %cmp1 = icmp eq i8 %v2, %v3
-  br i1 %cmp1, label %bb2, label %bb4
-
-bb2:                                             ; preds = %bb1
-  %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2
-  %v4 = load i8, ptr %s4, align 1
-  %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2
-  %v5 = load i8, ptr %s5, align 1
-  %cmp2 = icmp eq i8 %v4, %v5
-  br i1 %cmp2, label %bb3, label %bb4
-
-bb3:                                               ; preds = %bb2
-  %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3
-  %v6 = load i8, ptr %s6, align 1
-  %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3
-  %v7 = load i8, ptr %s7, align 1
-  %cmp3 = icmp ne i8 %v6, %v7
-  br label %bb4
-
-bb4:                                               ; preds = %bb0, %bb1, %bb2, %bb3
-  %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ]
-  ret i1 %cmp
-}
-
-; Negative test: Incorrect const value in PHI node
-define noundef i1 @cmp_ne_incorrect_const(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
-; CHECK-LABEL: @cmp_ne_incorrect_const(
-; CHECK-NEXT:  bb0:
-; CHECK-NEXT:    [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1
-; CHECK-NEXT:    [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1
-; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]]
-; CHECK-NEXT:    br i1 [[CMP0]], label [[BB1:%.*]], label [[BB2:%.*]]
-; CHECK:       bb1:
-; CHECK-NEXT:    [[S6:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1
-; CHECK-NEXT:    [[V6:%.*]] = load i8, ptr [[S6]], align 1
-; CHECK-NEXT:    [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1
-; CHECK-NEXT:    [[V7:%.*]] = load i8, ptr [[S7]], align 1
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]]
-; CHECK-NEXT:    br label [[BB2]]
-; CHECK:       bb2:
-; CHECK-NEXT:    [[CMP:%.*]] = phi i1 [ false, [[BB0:%.*]] ], [ [[CMP3]], [[BB1]] ]
-; CHECK-NEXT:    ret i1 [[CMP]]
-;
-bb0:
-  %v0 = load i8, ptr %s0, align 1
-  %v1 = load i8, ptr %s1, align 1
-  %cmp0 = icmp eq i8 %v0, %v1
-  br i1 %cmp0, label %bb1, label %bb2
-
-bb1:                                               ; preds = %bb0
-  %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
-  %v6 = load i8, ptr %s6, align 1
-  %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
-  %v7 = load i8, ptr %s7, align 1
-  %cmp3 = icmp ne i8 %v6, %v7
-  br label %bb2
-
-bb2:                                               ; preds = %bb0, %bb1
-  %cmp = phi i1 [ false, %bb0 ], [ %cmp3, %bb1 ]
-  ret i1 %cmp
-}