// Visit the given comparison block. If this is a comparison between two valid
// BCE atoms, returns the comparison.
-std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
- BasicBlock *const Block,
- const BasicBlock *const PhiBlock,
- BaseIdentifier &BaseId) {
+std::optional<BCECmpBlock>
+visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate,
+ Value *const Val, BasicBlock *const Block,
+ const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
if (Block->empty())
return std::nullopt;
auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
// that this does not mean that this is the last incoming value, blocks
// can be reordered).
Cond = Val;
- ExpectedPredicate = ICmpInst::ICMP_EQ;
+ const auto *const ConstBase = cast<ConstantInt>(Baseline);
+ assert(ConstBase->getType()->isIntegerTy(1) &&
+ "Select condition is not an i1?");
+ ExpectedPredicate =
+ ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ;
+
+ // Remember the correct predicate.
+ Predicate = ExpectedPredicate;
} else {
+ // All the incoming values must be consistent.
+ if (Baseline != Val)
+ return std::nullopt;
// In this case, we expect a constant incoming value (the comparison is
// chained).
const auto *const Const = cast<ConstantInt>(Val);
- LLVM_DEBUG(dbgs() << "const\n");
- if (!Const->isZero())
+ assert(Const->getType()->isIntegerTy(1) &&
+ "Incoming value is not an i1?");
+ LLVM_DEBUG(dbgs() << "const i1 value\n");
+ if (!Const->isZero() && !Const->isOne())
return std::nullopt;
- LLVM_DEBUG(dbgs() << "false\n");
+ LLVM_DEBUG(dbgs() << *Const << "\n");
assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
Cond = BranchI->getCondition();
std::vector<ContiguousBlocks> MergedBlocks_;
// The original entry block (before sorting);
BasicBlock *EntryBlock_;
+ // Remember the predicate type of the chain.
+ ICmpInst::Predicate Predicate_;
};
static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
// Now look inside blocks to check for BCE comparisons.
std::vector<BCECmpBlock> Comparisons;
BaseIdentifier BaseId;
+ Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]);
+ Predicate_ = CmpInst::BAD_ICMP_PREDICATE;
for (BasicBlock *const Block : Blocks) {
assert(Block && "invalid block");
- std::optional<BCECmpBlock> Comparison = visitCmpBlock(
- Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId);
+ std::optional<BCECmpBlock> Comparison =
+ visitCmpBlock(Baseline, Predicate_, Phi.getIncomingValueForBlock(Block),
+ Block, Phi.getParent(), BaseId);
if (!Comparison) {
LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
return;
BasicBlock *const InsertBefore,
BasicBlock *const NextCmpBlock,
PHINode &Phi, const TargetLibraryInfo &TLI,
- AliasAnalysis &AA, DomTreeUpdater &DTU) {
+ AliasAnalysis &AA, DomTreeUpdater &DTU,
+ ICmpInst::Predicate Predicate) {
assert(!Comparisons.empty() && "merging zero comparisons");
LLVMContext &Context = NextCmpBlock->getContext();
const BCECmpBlock &FirstCmp = Comparisons[0];
else
Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
- Value *IsEqual = nullptr;
+ Value *ICmpValue = nullptr;
LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
<< BB->getName() << "\n");
ToSplit->split(BB, AA);
}
+ // For an icmp chain, Predicate records the predicate of the last link in the
+ // chain of comparisons. When the chain is split, a new sub-chain that does not
+ // contain that last link ends with ICMP_EQ.
+ // Only the last link in the chain ends in an unconditional branch.
+ BasicBlock *const TailBB = Comparisons[Comparisons.size() - 1].BB;
+ auto *const BranchI = dyn_cast<BranchInst>(TailBB->getTerminator());
+ ICmpInst::Predicate Pred =
+ BranchI->isUnconditional() ? Predicate : ICmpInst::ICMP_EQ;
if (Comparisons.size() == 1) {
LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
// Use clone to keep the metadata
LhsLoad->replaceUsesOfWith(LhsLoad->getOperand(0), Lhs);
RhsLoad->replaceUsesOfWith(RhsLoad->getOperand(0), Rhs);
// There are no blocks to merge, just do the comparison.
- IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+ ICmpValue = Builder.CreateICmp(Pred, LhsLoad, RhsLoad);
} else {
const unsigned TotalSizeBits = std::accumulate(
Comparisons.begin(), Comparisons.end(), 0u,
Lhs, Rhs,
ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8),
Builder, DL, &TLI);
- IsEqual = Builder.CreateICmpEQ(
- MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
+ ICmpValue = Builder.CreateICmp(
+ Pred, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
}
BasicBlock *const PhiBB = Phi.getParent();
if (NextCmpBlock == PhiBB) {
// Continue to phi, passing it the comparison result.
Builder.CreateBr(PhiBB);
- Phi.addIncoming(IsEqual, BB);
+ Phi.addIncoming(ICmpValue, BB);
DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
} else {
// Continue to next block if equal, exit to phi else.
- Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
- Phi.addIncoming(ConstantInt::getFalse(Context), BB);
+ Builder.CreateCondBr(ICmpValue, NextCmpBlock, PhiBB);
+ Value *ConstantVal = Predicate == CmpInst::ICMP_NE
+ ? ConstantInt::getTrue(Context)
+ : ConstantInt::getFalse(Context);
+ Phi.addIncoming(ConstantVal, BB);
DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
{DominatorTree::Insert, BB, PhiBB}});
}
// so that the next block is always available to branch to.
BasicBlock *InsertBefore = EntryBlock_;
BasicBlock *NextCmpBlock = Phi_.getParent();
+ assert(Predicate_ != CmpInst::BAD_ICMP_PREDICATE &&
+ "Got the chain of comparisons");
for (const auto &Blocks : reverse(MergedBlocks_)) {
InsertBefore = NextCmpBlock = mergeComparisons(
- Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
+ Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate_);
}
// Replace the original cmp chain with the new cmp chain by pointing all
--- /dev/null
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s
+
+%struct.S = type { i8, i8, i8, i8 }
+%struct1.S = type { i32, i32, i32, i8 }
+%"struct.media::WebrtcVideoStatsDB::VideoDescKey" = type { i8, i32, i8, i32 }
+
+; Full chain ending in `icmp ne`: bb0, bb1 and bb2 compare fields 0, 1 and 2 of
+; %struct.S with `icmp eq` and exit to bb4 on mismatch, feeding `true` into the
+; phi; bb3 compares field 3 with `icmp ne` and branches unconditionally to bb4.
+; The expected output is a single 4-byte memcmp whose result is tested with
+; `icmp ne` against 0.
+define noundef i1 @full_sequent_ne(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
+; CHECK-LABEL: @full_sequent_ne(
+; CHECK-NEXT: "bb0+bb1+bb2+bb3":
+; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4)
+; CHECK-NEXT: [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0
+; CHECK-NEXT: br label [[BB4:%.*]]
+; CHECK: bb4:
+; CHECK-NEXT: ret i1 [[TMP0]]
+;
+bb0:
+ %v0 = load i8, ptr %s0, align 1
+ %v1 = load i8, ptr %s1, align 1
+ %cmp0 = icmp eq i8 %v0, %v1
+ br i1 %cmp0, label %bb1, label %bb4
+
+bb1: ; preds = %bb0
+ %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
+ %v2 = load i8, ptr %s2, align 1
+ %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
+ %v3 = load i8, ptr %s3, align 1
+ %cmp1 = icmp eq i8 %v2, %v3
+ br i1 %cmp1, label %bb2, label %bb4
+
+bb2: ; preds = %bb1
+ %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2
+ %v4 = load i8, ptr %s4, align 1
+ %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2
+ %v5 = load i8, ptr %s5, align 1
+ %cmp2 = icmp eq i8 %v4, %v5
+ br i1 %cmp2, label %bb3, label %bb4
+
+bb3: ; preds = %bb2
+ %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3
+ %v6 = load i8, ptr %s6, align 1
+ %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3
+ %v7 = load i8, ptr %s7, align 1
+ %cmp3 = icmp ne i8 %v6, %v7
+ br label %bb4
+
+bb4: ; preds = %bb0, %bb1, %bb2, %bb3
+ %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ]
+ ret i1 %cmp
+}
+
+; Negative test: Incorrect const value in PHI node
+; The last link (bb1) uses `icmp ne`, but the early exit from bb0 feeds `false`
+; into the phi instead of `true`. The expected output below is the input
+; function unchanged, i.e. no memcmp is formed.
+define noundef i1 @cmp_ne_incorrect_const(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) {
+; CHECK-LABEL: @cmp_ne_incorrect_const(
+; CHECK-NEXT: bb0:
+; CHECK-NEXT: [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1
+; CHECK-NEXT: [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]]
+; CHECK-NEXT: br i1 [[CMP0]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK: bb1:
+; CHECK-NEXT: [[S6:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1
+; CHECK-NEXT: [[V6:%.*]] = load i8, ptr [[S6]], align 1
+; CHECK-NEXT: [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1
+; CHECK-NEXT: [[V7:%.*]] = load i8, ptr [[S7]], align 1
+; CHECK-NEXT: [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]]
+; CHECK-NEXT: br label [[BB2]]
+; CHECK: bb2:
+; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ false, [[BB0:%.*]] ], [ [[CMP3]], [[BB1]] ]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+bb0:
+ %v0 = load i8, ptr %s0, align 1
+ %v1 = load i8, ptr %s1, align 1
+ %cmp0 = icmp eq i8 %v0, %v1
+ br i1 %cmp0, label %bb1, label %bb2
+
+bb1: ; preds = %bb0
+ %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1
+ %v6 = load i8, ptr %s6, align 1
+ %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1
+ %v7 = load i8, ptr %s7, align 1
+ %cmp3 = icmp ne i8 %v6, %v7
+ br label %bb2
+
+bb2: ; preds = %bb0, %bb1
+ %cmp = phi i1 [ false, %bb0 ], [ %cmp3, %bb1 ]
+ ret i1 %cmp
+}
+
+; https://alive2.llvm.org/ce/z/Zi2Z3Y
+; Partially mergeable chain ending in `icmp eq`: the blocks compare fields 0, 2
+; and 3 of %struct1.S (field 1 is never compared). Early exits feed `false` into
+; the phi. The expected output keeps the field-0 compare on its own and merges
+; bb1 and bb2 into a 5-byte memcmp tested with `icmp eq` against 0.
+define noundef i1 @partial_sequent_eq(ptr nocapture readonly dereferenceable(16) %s0, ptr nocapture readonly dereferenceable(16) %s1) {
+; CHECK-LABEL: @partial_sequent_eq(
+; CHECK-NEXT: bb01:
+; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[S0:%.*]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr [[S1:%.*]], align 8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: br i1 [[TMP2]], label %"bb1+bb2", label [[BB3:%.*]]
+; CHECK: "bb1+bb2":
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[STRUCT1_S:%.*]], ptr [[S0]], i64 0, i32 2
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT1_S]], ptr [[S1]], i64 0, i32 2
+; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[TMP3]], ptr [[TMP4]], i64 5)
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i32 [[MEMCMP]], 0
+; CHECK-NEXT: br label [[BB3]]
+; CHECK: bb3:
+; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ [[TMP5]], %"bb1+bb2" ], [ false, [[BB01:%.*]] ]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+bb0:
+ %v0 = load i32, ptr %s0, align 8
+ %v1 = load i32, ptr %s1, align 8
+ %cmp0 = icmp eq i32 %v0, %v1
+ br i1 %cmp0, label %bb1, label %bb3
+
+bb1: ; preds = %bb0
+ %s2 = getelementptr inbounds %struct1.S, ptr %s0, i64 0, i32 2
+ %v2 = load i32, ptr %s2, align 8
+ %s3 = getelementptr inbounds %struct1.S, ptr %s1, i64 0, i32 2
+ %v3 = load i32, ptr %s3, align 8
+ %cmp1 = icmp eq i32 %v2, %v3
+ br i1 %cmp1, label %bb2, label %bb3
+
+bb2: ; preds = %bb1
+ %s6 = getelementptr inbounds %struct1.S, ptr %s0, i64 0, i32 3
+ %v6 = load i8, ptr %s6, align 1
+ %s7 = getelementptr inbounds %struct1.S, ptr %s1, i64 0, i32 3
+ %v7 = load i8, ptr %s7, align 1
+ %cmp3 = icmp eq i8 %v6, %v7
+ br label %bb3
+
+bb3: ; preds = %bb0, %bb1, %bb2
+ %cmp = phi i1 [ false, %bb0 ], [ false, %bb1 ], [ %cmp3, %bb2 ]
+ ret i1 %cmp
+}
+
+; https://alive2.llvm.org/ce/z/sL5Uz6
+; Partially mergeable chain ending in `icmp ne`: the blocks compare fields 0, 2
+; and 3 of %struct1.S (field 1 is never compared). Early exits feed `true` into
+; the phi. The expected output keeps the field-0 compare on its own (still
+; `icmp eq`, since it is not the last link) and merges bb1 and bb2 into a 5-byte
+; memcmp tested with `icmp ne` against 0.
+define noundef i1 @partial_sequent_ne(ptr nocapture readonly dereferenceable(16) %s0, ptr nocapture readonly dereferenceable(16) %s1) {
+; CHECK-LABEL: @partial_sequent_ne(
+; CHECK-NEXT: bb01:
+; CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[S0:%.*]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr [[S1:%.*]], align 8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: br i1 [[TMP2]], label %"bb1+bb2", label [[BB3:%.*]]
+; CHECK: "bb1+bb2":
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[STRUCT1_S:%.*]], ptr [[S0]], i64 0, i32 2
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT1_S]], ptr [[S1]], i64 0, i32 2
+; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[TMP3]], ptr [[TMP4]], i64 5)
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ne i32 [[MEMCMP]], 0
+; CHECK-NEXT: br label [[BB3]]
+; CHECK: bb3:
+; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ [[TMP5]], %"bb1+bb2" ], [ true, [[BB01:%.*]] ]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+bb0:
+ %v0 = load i32, ptr %s0, align 8
+ %v1 = load i32, ptr %s1, align 8
+ %cmp0 = icmp eq i32 %v0, %v1
+ br i1 %cmp0, label %bb1, label %bb3
+
+bb1: ; preds = %bb0
+ %s2 = getelementptr inbounds %struct1.S, ptr %s0, i64 0, i32 2
+ %v2 = load i32, ptr %s2, align 8
+ %s3 = getelementptr inbounds %struct1.S, ptr %s1, i64 0, i32 2
+ %v3 = load i32, ptr %s3, align 8
+ %cmp1 = icmp eq i32 %v2, %v3
+ br i1 %cmp1, label %bb2, label %bb3
+
+bb2: ; preds = %bb1
+ %s6 = getelementptr inbounds %struct1.S, ptr %s0, i64 0, i32 3
+ %v6 = load i8, ptr %s6, align 1
+ %s7 = getelementptr inbounds %struct1.S, ptr %s1, i64 0, i32 3
+ %v7 = load i8, ptr %s7, align 1
+ %cmp3 = icmp ne i8 %v6, %v7
+ br label %bb3
+
+bb3: ; preds = %bb0, %bb1, %bb2
+ %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ %cmp3, %bb2 ]
+ ret i1 %cmp
+}
+
+; https://alive2.llvm.org/ce/z/EQtb_S
+; Chain over all four fields of the VideoDescKey struct ending in `icmp ne`,
+; with `true` fed into the phi on every early exit. The expected output splits
+; the chain into three links: the field-0 compare, a 5-byte memcmp covering bb1
+; and bb2, and the field-3 compare. Only the last link uses `icmp ne`; the two
+; earlier links use `icmp eq` and branch conditionally.
+define i1 @WebrtcVideoStats(ptr nocapture noundef dereferenceable(16) %S0, ptr nocapture noundef dereferenceable(16) %S1) {
+; CHECK-LABEL: @WebrtcVideoStats(
+; CHECK-NEXT: bb02:
+; CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[S0:%.*]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = load i8, ptr [[S1:%.*]], align 4
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: br i1 [[TMP2]], label %"bb1+bb2", label [[BB4:%.*]]
+; CHECK: "bb1+bb2":
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr [[S0]], i64 0, i32 1
+; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr [[S1]], i64 0, i32 1
+; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[TMP3]], ptr [[TMP4]], i64 5)
+; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i32 [[MEMCMP]], 0
+; CHECK-NEXT: br i1 [[TMP5]], label [[BB31:%.*]], label [[BB4]]
+; CHECK: bb31:
+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr [[S0]], i64 0, i32 3
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr [[S1]], i64 0, i32 3
+; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr [[TMP6]], align 4
+; CHECK-NEXT: [[TMP9:%.*]] = load i32, ptr [[TMP7]], align 4
+; CHECK-NEXT: [[TMP10:%.*]] = icmp ne i32 [[TMP8]], [[TMP9]]
+; CHECK-NEXT: br label [[BB4]]
+; CHECK: bb4:
+; CHECK-NEXT: [[RESULT:%.*]] = phi i1 [ [[TMP10]], [[BB31]] ], [ true, %"bb1+bb2" ], [ true, [[BB02:%.*]] ]
+; CHECK-NEXT: ret i1 [[RESULT]]
+;
+bb0:
+ %V0 = load i8, ptr %S0, align 4
+ %V1 = load i8, ptr %S1, align 4
+ %Cmp0 = icmp eq i8 %V0, %V1
+ br i1 %Cmp0, label %bb1, label %bb4
+
+bb1: ; preds = %bb0
+ %Base2 = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr %S0, i64 0, i32 1
+ %V2 = load i32, ptr %Base2, align 4
+ %Base3 = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr %S1, i64 0, i32 1
+ %V3 = load i32, ptr %Base3, align 4
+ %Cmp1 = icmp eq i32 %V2, %V3
+ br i1 %Cmp1, label %bb2, label %bb4
+
+bb2: ; preds = %bb1
+ %Base4 = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr %S0, i64 0, i32 2
+ %V4 = load i8, ptr %Base4, align 4
+ %Base5 = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr %S1, i64 0, i32 2
+ %V5 = load i8, ptr %Base5, align 4
+ %Cmp2 = icmp eq i8 %V4, %V5
+ br i1 %Cmp2, label %bb3, label %bb4
+
+bb3: ; preds = %bb2
+ %Base6 = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr %S0, i64 0, i32 3
+ %V6 = load i32, ptr %Base6, align 4
+ %Base7 = getelementptr inbounds %"struct.media::WebrtcVideoStatsDB::VideoDescKey", ptr %S1, i64 0, i32 3
+ %V7 = load i32, ptr %Base7, align 4
+ %Cmp3 = icmp ne i32 %V6, %V7
+ br label %bb4
+
+bb4: ; preds = %bb3, %bb2, %bb1, %bb0
+ %result = phi i1 [ %Cmp3, %bb3 ], [ true, %bb2 ], [ true, %bb1 ], [ true, %bb0 ]
+ ret i1 %result
+}