From a3f4f0828bb791d2b7b396ccc61a95f4b0e76ba6 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Tue, 16 Aug 2016 17:54:36 +0000 Subject: [PATCH] [InstCombine] add helper functions for foldICmpWithConstant; NFCI Besides breaking up a 700 line function to improve readability, this sinks the 'FIXME: ConstantInt' check into each helper. So now we can independently break that restriction within any of the helper functions. As much as possible, the code was only {cut/paste/clang-format}'ed to minimize risk (no functional changes intended), so several more readability improvements are still possible. llvm-svn: 278828 --- .../Transforms/InstCombine/InstCombineCompares.cpp | 1283 +++++++++++--------- .../Transforms/InstCombine/InstCombineInternal.h | 32 +- 2 files changed, 726 insertions(+), 589 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 95dab35..4cbb22a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1179,8 +1179,9 @@ Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, /// Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS and CmpRHS are /// both known to be integer constants. -Instruction *InstCombiner::foldICmpDivConst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS) { +Instruction *InstCombiner::foldICmpDivConstConst(ICmpInst &ICI, + BinaryOperator *DivI, + ConstantInt *DivRHS) { ConstantInt *CmpRHS = cast(ICI.getOperand(1)); const APInt &CmpRHSV = CmpRHS->getValue(); @@ -1335,8 +1336,9 @@ Instruction *InstCombiner::foldICmpDivConst(ICmpInst &ICI, BinaryOperator *DivI, } /// Handle "icmp(([al]shr X, cst1), cst2)". -Instruction *InstCombiner::foldICmpShrConst(ICmpInst &ICI, BinaryOperator *Shr, - ConstantInt *ShAmt) { +Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &ICI, + BinaryOperator *Shr, + ConstantInt *ShAmt) { const APInt &CmpRHSV = cast(ICI.getOperand(1))->getValue(); // Check that the shift amount is in range. If not, don't perform @@ -1382,7 +1384,8 @@ Instruction *InstCombiner::foldICmpShrConst(ICmpInst &ICI, BinaryOperator *Shr, assert(TheDiv->getOpcode() == Instruction::SDiv || TheDiv->getOpcode() == Instruction::UDiv); - Instruction *Res = foldICmpDivConst(ICI, TheDiv, cast(DivCst)); + Instruction *Res = + foldICmpDivConstConst(ICI, TheDiv, cast(DivCst)); assert(Res && "This div/cst should have folded!"); return Res; } @@ -1530,672 +1533,782 @@ Instruction *InstCombiner::foldICmpCstShlConst(ICmpInst &I, Value *Op, Value *A, return getConstant(false); } -/// Try to fold integer comparisons with a constant operand: icmp Pred X, C. -Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI) { - Instruction *LHSI; - const APInt *RHSV; - if (!match(ICI.getOperand(0), m_Instruction(LHSI)) || - !match(ICI.getOperand(1), m_APInt(RHSV))) - return nullptr; - +Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &ICI, + Instruction *LHSI, + const APInt *RHSV) { // FIXME: This check restricts all folds under here to scalar types. ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); if (!RHS) return nullptr; - switch (LHSI->getOpcode()) { - case Instruction::Trunc: - if (RHS->isOne() && RHSV->getBitWidth() > 1) { - // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI->getOperand(0), m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); - } - if (ICI.isEquality() && LHSI->hasOneUse()) { - // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all - // of the high bits truncated out of x are known. - unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), + if (RHS->isOne() && RHSV->getBitWidth() > 1) { + // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (ICI.getPredicate() == ICmpInst::ICMP_SLT && + match(LHSI->getOperand(0), m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + if (ICI.isEquality() && LHSI->hasOneUse()) { + // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all + // of the high bits truncated out of x are known. + unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); - APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); - - // If all the high bits are known, we can do this xform. - if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { - // Pull in the high bits from known-ones set. - APInt NewRHS = RHS->getValue().zext(SrcBits); - NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits-DstBits); - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - Builder->getInt(NewRHS)); - } + APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); + + // If all the high bits are known, we can do this xform. + if ((KnownZero | KnownOne).countLeadingOnes() >= SrcBits - DstBits) { + // Pull in the high bits from known-ones set. + APInt NewRHS = RHS->getValue().zext(SrcBits); + NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + Builder->getInt(NewRHS)); } - break; + } + return nullptr; +} - case Instruction::Xor: // (icmp pred (xor X, XorCst), CI) - if (ConstantInt *XorCst = dyn_cast(LHSI->getOperand(1))) { - // If this is a comparison that tests the signbit (X < 0) or (x > -1), - // fold the xor. - if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && *RHSV == 0) || - (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV->isAllOnesValue())) { - Value *CompareVal = LHSI->getOperand(0); - - // If the sign bit of the XorCst is not set, there is no change to - // the operation, just stop using the Xor. - if (!XorCst->isNegative()) { - ICI.setOperand(0, CompareVal); - Worklist.Add(LHSI); - return &ICI; - } +Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; + if (ConstantInt *XorCst = dyn_cast(LHSI->getOperand(1))) { + // If this is a comparison that tests the signbit (X < 0) or (x > -1), + // fold the xor. + if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && *RHSV == 0) || + (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV->isAllOnesValue())) { + Value *CompareVal = LHSI->getOperand(0); + + // If the sign bit of the XorCst is not set, there is no change to + // the operation, just stop using the Xor. + if (!XorCst->isNegative()) { + ICI.setOperand(0, CompareVal); + Worklist.Add(LHSI); + return &ICI; + } - // If so, the new one isn't. - isTrueIfPositive ^= true; + // Was the old condition true if the operand is positive? + bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, - SubOne(RHS)); - else - return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, - AddOne(RHS)); - } + // If so, the new one isn't. + isTrueIfPositive ^= true; - if (LHSI->hasOneUse()) { - // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) - if (!ICI.isEquality() && XorCst->getValue().isSignBit()) { - const APInt &SignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(*RHSV ^ SignBit)); - } + if (isTrueIfPositive) + return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, SubOne(RHS)); + else + return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, AddOne(RHS)); + } - // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) - if (!ICI.isEquality() && XorCst->isMaxValue(true)) { - const APInt &NotSignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - Pred = ICI.getSwappedPredicate(Pred); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(*RHSV ^ NotSignBit)); - } + if (LHSI->hasOneUse()) { + // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) + if (!ICI.isEquality() && XorCst->getValue().isSignBit()) { + const APInt &SignBit = XorCst->getValue(); + ICmpInst::Predicate Pred = ICI.isSigned() ? ICI.getUnsignedPredicate() + : ICI.getSignedPredicate(); + return new ICmpInst(Pred, LHSI->getOperand(0), + Builder->getInt(*RHSV ^ SignBit)); } - // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && - XorCst->getValue() == ~(*RHSV) && (*RHSV + 1).isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); - - // (icmp ult (xor X, C), -C) -> (icmp uge X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && - XorCst->getValue() == -(*RHSV) && RHSV->isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); + // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) + if (!ICI.isEquality() && XorCst->isMaxValue(true)) { + const APInt &NotSignBit = XorCst->getValue(); + ICmpInst::Predicate Pred = ICI.isSigned() ? ICI.getUnsignedPredicate() + : ICI.getSignedPredicate(); + Pred = ICI.getSwappedPredicate(Pred); + return new ICmpInst(Pred, LHSI->getOperand(0), + Builder->getInt(*RHSV ^ NotSignBit)); + } } - break; - case Instruction::And: // (icmp pred (and X, AndCst), RHS) - if (LHSI->hasOneUse() && isa(LHSI->getOperand(1)) && - LHSI->getOperand(0)->hasOneUse()) { - ConstantInt *AndCst = cast(LHSI->getOperand(1)); - - // If the LHS is an AND of a truncating cast, we can widen the - // and/compare to be the input width without changing the value - // produced, eliminating a cast. - if (TruncInst *Cast = dyn_cast(LHSI->getOperand(0))) { - // We can do this transformation if either the AND constant does not - // have its sign bit set or if it is an equality comparison. - // Extending a relational comparison when we're checking the sign - // bit would not work. - if (ICI.isEquality() || - (!AndCst->isNegative() && RHSV->isNonNegative())) { - Value *NewAnd = + + // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) + // iff -C is a power of 2 + if (ICI.getPredicate() == ICmpInst::ICMP_UGT && + XorCst->getValue() == ~(*RHSV) && (*RHSV + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); + + // (icmp ult (xor X, C), -C) -> (icmp uge X, C) + // iff -C is a power of 2 + if (ICI.getPredicate() == ICmpInst::ICMP_ULT && + XorCst->getValue() == -(*RHSV) && RHSV->isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); + } + return nullptr; +} + +Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + if (LHSI->hasOneUse() && isa(LHSI->getOperand(1)) && + LHSI->getOperand(0)->hasOneUse()) { + ConstantInt *AndCst = cast(LHSI->getOperand(1)); + + // If the LHS is an AND of a truncating cast, we can widen the + // and/compare to be the input width without changing the value + // produced, eliminating a cast. + if (TruncInst *Cast = dyn_cast(LHSI->getOperand(0))) { + // We can do this transformation if either the AND constant does not + // have its sign bit set or if it is an equality comparison. + // Extending a relational comparison when we're checking the sign + // bit would not work. + if (ICI.isEquality() || + (!AndCst->isNegative() && RHSV->isNonNegative())) { + Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), ConstantExpr::getZExt(AndCst, Cast->getSrcTy())); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getZExt(RHS, Cast->getSrcTy())); - } + NewAnd->takeName(LHSI); + return new ICmpInst(ICI.getPredicate(), NewAnd, + ConstantExpr::getZExt(RHS, Cast->getSrcTy())); } + } - // If the LHS is an AND of a zext, and we have an equality compare, we can - // shrink the and/compare to the smaller type, eliminating the cast. - if (ZExtInst *Cast = dyn_cast(LHSI->getOperand(0))) { - IntegerType *Ty = cast(Cast->getSrcTy()); - // Make sure we don't compare the upper bits, SimplifyDemandedBits - // should fold the icmp to true/false in that case. - if (ICI.isEquality() && RHSV->getActiveBits() <= Ty->getBitWidth()) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getTrunc(AndCst, Ty)); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getTrunc(RHS, Ty)); - } + // If the LHS is an AND of a zext, and we have an equality compare, we can + // shrink the and/compare to the smaller type, eliminating the cast. + if (ZExtInst *Cast = dyn_cast(LHSI->getOperand(0))) { + IntegerType *Ty = cast(Cast->getSrcTy()); + // Make sure we don't compare the upper bits, SimplifyDemandedBits + // should fold the icmp to true/false in that case. + if (ICI.isEquality() && RHSV->getActiveBits() <= Ty->getBitWidth()) { + Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), + ConstantExpr::getTrunc(AndCst, Ty)); + NewAnd->takeName(LHSI); + return new ICmpInst(ICI.getPredicate(), NewAnd, + ConstantExpr::getTrunc(RHS, Ty)); } + } - // If this is: (X >> C1) & C2 != C3 (where any shift and any compare - // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This - // happens a LOT in code produced by the C front-end, for bitfield - // access. - BinaryOperator *Shift = dyn_cast(LHSI->getOperand(0)); - if (Shift && !Shift->isShift()) - Shift = nullptr; - - ConstantInt *ShAmt; - ShAmt = Shift ? dyn_cast(Shift->getOperand(1)) : nullptr; - - // This seemingly simple opportunity to fold away a shift turns out to - // be rather complicated. See PR17827 - // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. - if (ShAmt) { - bool CanFold = false; - unsigned ShiftOpcode = Shift->getOpcode(); - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, - // but nothing simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { - // For a left shift, we can fold if the comparison is not signed. - // We can also fold a signed comparison if the mask value and - // comparison value are not negative. These constraints may not be - // obvious, but we can prove that they are correct using an SMT - // solver. - if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) - CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { - // For a logical right shift, we can fold if the comparison is not - // signed. We can also fold a signed comparison if the shifted mask - // value and the shifted comparison value are not negative. - // These constraints may not be obvious, but we can prove that they - // are correct using an SMT solver. - if (!ICI.isSigned()) - CanFold = true; - else { - ConstantInt *ShiftedAndCst = + // If this is: (X >> C1) & C2 != C3 (where any shift and any compare + // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This + // happens a LOT in code produced by the C front-end, for bitfield + // access. + BinaryOperator *Shift = dyn_cast(LHSI->getOperand(0)); + if (Shift && !Shift->isShift()) + Shift = nullptr; + + ConstantInt *ShAmt; + ShAmt = Shift ? dyn_cast(Shift->getOperand(1)) : nullptr; + + // This seemingly simple opportunity to fold away a shift turns out to + // be rather complicated. See PR17827 + // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. + if (ShAmt) { + bool CanFold = false; + unsigned ShiftOpcode = Shift->getOpcode(); + if (ShiftOpcode == Instruction::AShr) { + // There may be some constraints that make this possible, + // but nothing simple has been discovered yet. + CanFold = false; + } else if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the mask value and + // comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT + // solver. + if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) + CanFold = true; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not + // signed. We can also fold a signed comparison if the shifted mask + // value and the shifted comparison value are not negative. + // These constraints may not be obvious, but we can prove that they + // are correct using an SMT solver. + if (!ICI.isSigned()) + CanFold = true; + else { + ConstantInt *ShiftedAndCst = cast(ConstantExpr::getShl(AndCst, ShAmt)); - ConstantInt *ShiftedRHSCst = + ConstantInt *ShiftedRHSCst = cast(ConstantExpr::getShl(RHS, ShAmt)); - if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) - CanFold = true; - } + if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) + CanFold = true; } + } - if (CanFold) { - Constant *NewCst; + if (CanFold) { + Constant *NewCst; + if (ShiftOpcode == Instruction::Shl) + NewCst = ConstantExpr::getLShr(RHS, ShAmt); + else + NewCst = ConstantExpr::getShl(RHS, ShAmt); + + // Check to see if we are shifting out any of the bits being + // compared. + if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { + // If we shifted bits out, the fold is not going to work out. + // As a special case, check to see if this means that the + // result is always true or false now. + if (ICI.getPredicate() == ICmpInst::ICMP_EQ) + return replaceInstUsesWith(ICI, Builder->getFalse()); + if (ICI.getPredicate() == ICmpInst::ICMP_NE) + return replaceInstUsesWith(ICI, Builder->getTrue()); + } else { + ICI.setOperand(1, NewCst); + Constant *NewAndCst; if (ShiftOpcode == Instruction::Shl) - NewCst = ConstantExpr::getLShr(RHS, ShAmt); + NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); else - NewCst = ConstantExpr::getShl(RHS, ShAmt); - - // Check to see if we are shifting out any of the bits being - // compared. - if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { - // If we shifted bits out, the fold is not going to work out. - // As a special case, check to see if this means that the - // result is always true or false now. - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return replaceInstUsesWith(ICI, Builder->getTrue()); - } else { - ICI.setOperand(1, NewCst); - Constant *NewAndCst; - if (ShiftOpcode == Instruction::Shl) - NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); - else - NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); - LHSI->setOperand(1, NewAndCst); - LHSI->setOperand(0, Shift->getOperand(0)); - Worklist.Add(Shift); // Shift is dead. - return &ICI; - } + NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); + LHSI->setOperand(1, NewAndCst); + LHSI->setOperand(0, Shift->getOperand(0)); + Worklist.Add(Shift); // Shift is dead. + return &ICI; } } + } - // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is - // preferable because it allows the C<hasOneUse() && *RHSV == 0 && - ICI.isEquality() && !Shift->isArithmeticShift() && - !isa(Shift->getOperand(0))) { - // Compute C << Y. - Value *NS; - if (Shift->getOpcode() == Instruction::LShr) { - NS = Builder->CreateShl(AndCst, Shift->getOperand(1)); - } else { - // Insert a logical shift. - NS = Builder->CreateLShr(AndCst, Shift->getOperand(1)); - } + // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is + // preferable because it allows the C<hasOneUse() && *RHSV == 0 && ICI.isEquality() && + !Shift->isArithmeticShift() && !isa(Shift->getOperand(0))) { + // Compute C << Y. + Value *NS; + if (Shift->getOpcode() == Instruction::LShr) { + NS = Builder->CreateShl(AndCst, Shift->getOperand(1)); + } else { + // Insert a logical shift. + NS = Builder->CreateLShr(AndCst, Shift->getOperand(1)); + } - // Compute X & (C << Y). - Value *NewAnd = + // Compute X & (C << Y). + Value *NewAnd = Builder->CreateAnd(Shift->getOperand(0), NS, LHSI->getName()); - ICI.setOperand(0, NewAnd); - return &ICI; - } + ICI.setOperand(0, NewAnd); + return &ICI; + } - // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> - // (icmp pred (and X, (or (shl 1, Y), 1), 0)) - // - // iff pred isn't signed - { - Value *X, *Y, *LShr; - if (!ICI.isSigned() && *RHSV == 0) { - if (match(LHSI->getOperand(1), m_One())) { - Constant *One = cast(LHSI->getOperand(1)); - Value *Or = LHSI->getOperand(0); - if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && - match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { - unsigned UsesRemoved = 0; - if (LHSI->hasOneUse()) - ++UsesRemoved; - if (Or->hasOneUse()) - ++UsesRemoved; - if (LShr->hasOneUse()) - ++UsesRemoved; - Value *NewOr = nullptr; - // Compute X & ((1 << Y) | 1) - if (auto *C = dyn_cast(Y)) { - if (UsesRemoved >= 1) - NewOr = - ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, - LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); - } - if (NewOr) { - Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); - ICI.setOperand(0, NewAnd); - return &ICI; - } + // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> + // (icmp pred (and X, (or (shl 1, Y), 1), 0)) + // + // iff pred isn't signed + { + Value *X, *Y, *LShr; + if (!ICI.isSigned() && *RHSV == 0) { + if (match(LHSI->getOperand(1), m_One())) { + Constant *One = cast(LHSI->getOperand(1)); + Value *Or = LHSI->getOperand(0); + if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && + match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { + unsigned UsesRemoved = 0; + if (LHSI->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + Value *NewOr = nullptr; + // Compute X & ((1 << Y) | 1) + if (auto *C = dyn_cast(Y)) { + if (UsesRemoved >= 1) + NewOr = + ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, + LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); + ICI.setOperand(0, NewAnd); + return &ICI; } } } } - - // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any - // bit set in (X & AndCst) will produce a result greater than RHSV. - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - unsigned NTZ = AndCst->getValue().countTrailingZeros(); - if ((NTZ < AndCst->getBitWidth()) && - APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(*RHSV)) - return new ICmpInst(ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); - } } - // Try to optimize things like "A[i]&42 == 0" to index computations. - if (LoadInst *LI = dyn_cast(LHSI->getOperand(0))) { - if (GetElementPtrInst *GEP = - dyn_cast(LI->getOperand(0))) - if (GlobalVariable *GV = dyn_cast(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !LI->isVolatile() && isa(LHSI->getOperand(1))) { - ConstantInt *C = cast(LHSI->getOperand(1)); - if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV,ICI, C)) - return Res; - } + // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any + // bit set in (X & AndCst) will produce a result greater than RHSV. + if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { + unsigned NTZ = AndCst->getValue().countTrailingZeros(); + if ((NTZ < AndCst->getBitWidth()) && + APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(*RHSV)) + return new ICmpInst(ICmpInst::ICMP_NE, LHSI, + Constant::getNullValue(RHS->getType())); } + } - // X & -C == -C -> X > u ~C - // X & -C != -C -> X <= u ~C - // iff C is a power of 2 - if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-(*RHSV)).isPowerOf2()) - return new ICmpInst( - ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE, - LHSI->getOperand(0), SubOne(RHS)); - - // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1) - // iff C is a power of 2 - if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) { - if (auto *CI = dyn_cast(LHSI->getOperand(1))) { - const APInt &AI = CI->getValue(); - int32_t ExactLogBase2 = AI.exactLogBase2(); - if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { - Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1); - Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy); - return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_SGE - : ICmpInst::ICMP_SLT, - Trunc, Constant::getNullValue(NTy)); + // Try to optimize things like "A[i]&42 == 0" to index computations. + if (LoadInst *LI = dyn_cast(LHSI->getOperand(0))) { + if (GetElementPtrInst *GEP = dyn_cast(LI->getOperand(0))) + if (GlobalVariable *GV = dyn_cast(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !LI->isVolatile() && isa(LHSI->getOperand(1))) { + ConstantInt *C = cast(LHSI->getOperand(1)); + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, ICI, C)) + return Res; } + } + + // X & -C == -C -> X > u ~C + // X & -C != -C -> X <= u ~C + // iff C is a power of 2 + if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-(*RHSV)).isPowerOf2()) + return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_UGT + : ICmpInst::ICMP_ULE, + LHSI->getOperand(0), SubOne(RHS)); + + // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1) + // iff C is a power of 2 + if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) { + if (auto *CI = dyn_cast(LHSI->getOperand(1))) { + const APInt &AI = CI->getValue(); + int32_t ExactLogBase2 = AI.exactLogBase2(); + if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { + Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1); + Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy); + return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_SGE + : ICmpInst::ICMP_SLT, + Trunc, Constant::getNullValue(NTy)); } } - break; + } + return nullptr; +} - case Instruction::Or: { - if (RHS->isOne()) { - // icmp slt signum(V) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI, m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); - } +Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; - if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse()) - break; - Value *P, *Q; - if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { - // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 - // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, - Constant::getNullValue(P->getType())); - Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, - Constant::getNullValue(Q->getType())); - Instruction *Op; - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - Op = BinaryOperator::CreateAnd(ICIP, ICIQ); - else - Op = BinaryOperator::CreateOr(ICIP, ICIQ); - return Op; - } - break; + if (RHS->isOne()) { + // icmp slt signum(V) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (ICI.getPredicate() == ICmpInst::ICMP_SLT && + match(LHSI, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); } - case Instruction::Mul: { // (icmp pred (mul X, Val), CI) - ConstantInt *Val = dyn_cast(LHSI->getOperand(1)); - if (!Val) break; - - // If this is a signed comparison to 0 and the mul is sign preserving, - // use the mul LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && !Val->isZero() && - cast(LHSI)->hasNoSignedWrap()) - return new ICmpInst(Val->isNegative() ? - ICmpInst::getSwappedPredicate(pred) : pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); + if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse()) + return nullptr; - break; + Value *P, *Q; + if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { + // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 + // -> and (icmp eq P, null), (icmp eq Q, null). + Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, + Constant::getNullValue(P->getType())); + Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, + Constant::getNullValue(Q->getType())); + Instruction *Op; + if (ICI.getPredicate() == ICmpInst::ICMP_EQ) + Op = BinaryOperator::CreateAnd(ICIP, ICIQ); + else + Op = BinaryOperator::CreateOr(ICIP, ICIQ); + return Op; } + return nullptr; +} - case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) - uint32_t TypeBits = RHSV->getBitWidth(); - ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); - if (!ShAmt) { - Value *X; - // (1 << X) pred P2 -> X pred Log2(P2) - if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { - bool RHSVIsPowerOf2 = RHSV->isPowerOf2(); - ICmpInst::Predicate Pred = ICI.getPredicate(); - if (ICI.isUnsigned()) { - if (!RHSVIsPowerOf2) { - // (1 << X) < 30 -> X <= 4 - // (1 << X) <= 30 -> X <= 4 - // (1 << X) >= 30 -> X > 4 - // (1 << X) > 30 -> X > 4 - if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_ULE; - else if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_UGT; - } - unsigned RHSLog2 = RHSV->logBase2(); - - // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) < 2147483648 -> X < 31 -> X != 31 - if (RHSLog2 == TypeBits-1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } +Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; - return new ICmpInst(Pred, X, - ConstantInt::get(RHS->getType(), RHSLog2)); - } else if (ICI.isSigned()) { - if (RHSV->isAllOnesValue()) { - // (1 << X) <= -1 -> X == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - - // (1 << X) > -1 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } else if (!(*RHSV)) { - // (1 << X) < 0 -> X == 31 - // (1 << X) <= 0 -> X == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - - // (1 << X) >= 0 -> X != 31 - // (1 << X) > 0 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } - } else if (ICI.isEquality()) { - if (RHSVIsPowerOf2) - return new ICmpInst( - Pred, X, ConstantInt::get(RHS->getType(), RHSV->logBase2())); + ConstantInt *Val = dyn_cast(LHSI->getOperand(1)); + if (!Val) + return nullptr; + + // If this is a signed comparison to 0 and the mul is sign preserving, + // use the mul LHS operand instead. + ICmpInst::Predicate pred = ICI.getPredicate(); + if (isSignTest(pred, RHS) && !Val->isZero() && + cast(LHSI)->hasNoSignedWrap()) + return new ICmpInst(Val->isNegative() ? + ICmpInst::getSwappedPredicate(pred) : pred, + LHSI->getOperand(0), + Constant::getNullValue(RHS->getType())); + + return nullptr; +} + +Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + uint32_t TypeBits = RHSV->getBitWidth(); + ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); + if (!ShAmt) { + Value *X; + // (1 << X) pred P2 -> X pred Log2(P2) + if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { + bool RHSVIsPowerOf2 = RHSV->isPowerOf2(); + ICmpInst::Predicate Pred = ICI.getPredicate(); + if (ICI.isUnsigned()) { + if (!RHSVIsPowerOf2) { + // (1 << X) < 30 -> X <= 4 + // (1 << X) <= 30 -> X <= 4 + // (1 << X) >= 30 -> X > 4 + // (1 << X) > 30 -> X > 4 + if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_ULE; + else if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_UGT; + } + unsigned RHSLog2 = RHSV->logBase2(); + + // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 + // (1 << X) < 2147483648 -> X < 31 -> X != 31 + if (RHSLog2 == TypeBits - 1) { + if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_EQ; + else if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_NE; } - } - break; - } - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - if (ShAmt->uge(TypeBits)) - break; + return new ICmpInst(Pred, X, ConstantInt::get(RHS->getType(), RHSLog2)); + } else if (ICI.isSigned()) { + if (RHSV->isAllOnesValue()) { + // (1 << X) <= -1 -> X == 31 + if (Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(RHS->getType(), TypeBits - 1)); - if (ICI.isEquality()) { - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - Constant *Comp = - ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), - ShAmt); - if (Comp != RHS) {// Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + // (1 << X) > -1 -> X != 31 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(RHS->getType(), TypeBits - 1)); + } else if (!(*RHSV)) { + // (1 << X) < 0 -> X == 31 + // (1 << X) <= 0 -> X == 31 + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(RHS->getType(), TypeBits - 1)); - // If the shift is NUW, then it is just shifting out zeros, no need for an - // AND. - if (cast(LHSI)->hasNoUnsignedWrap()) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); - - // If the shift is NSW and we compare to 0, then it is just shifting out - // sign bits, no need for an AND either. - if (cast(LHSI)->hasNoSignedWrap() && *RHSV == 0) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - Constant *Mask = Builder->getInt(APInt::getLowBitsSet(TypeBits, - TypeBits - ShAmtVal)); - - Value *And = - Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getLShr(RHS, ShAmt)); + // (1 << X) >= 0 -> X != 31 + // (1 << X) > 0 -> X != 31 + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(RHS->getType(), TypeBits - 1)); + } + } else if (ICI.isEquality()) { + if (RHSVIsPowerOf2) + return new ICmpInst( + Pred, X, ConstantInt::get(RHS->getType(), RHSV->logBase2())); } } + return nullptr; + } + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + if (ShAmt->uge(TypeBits)) + return nullptr; - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && - cast(LHSI)->hasNoSignedWrap()) - return new ICmpInst(pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); - - // Otherwise, if this is a comparison of the sign bit, simplify to and/test. - bool TrueIfSigned = false; - if (LHSI->hasOneUse() && - isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { - // (X << 31) (X&1) != 0 - Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(), - APInt::getOneBitSet(TypeBits, - TypeBits-ShAmt->getZExtValue()-1)); - Value *And = - Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask"); - return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, - And, Constant::getNullValue(And->getType())); + if (ICI.isEquality()) { + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + Constant *Comp = + ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), ShAmt); + if (Comp != RHS) { // Comparing against a bit that we know is zero. + bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = Builder->getInt1(IsICMP_NE); + return replaceInstUsesWith(ICI, Cst); } - // Transform (icmp pred iM (shl iM %v, N), CI) - // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (CI>>N)) - // Transform the shl to a trunc if (trunc (CI>>N)) has no loss and M-N. - // This enables to get rid of the shift in favor of a trunc which can be - // free on the target. It has the additional benefit of comparing to a - // smaller constant, which will be target friendly. - unsigned Amt = ShAmt->getLimitedValue(TypeBits-1); - if (LHSI->hasOneUse() && - Amt != 0 && RHSV->countTrailingZeros() >= Amt) { - Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt); - Constant *NCI = ConstantExpr::getTrunc( - ConstantExpr::getAShr(RHS, - ConstantInt::get(RHS->getType(), Amt)), - NTy); - return new ICmpInst(ICI.getPredicate(), - Builder->CreateTrunc(LHSI->getOperand(0), NTy), - NCI); + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + if (cast(LHSI)->hasNoUnsignedWrap()) + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantExpr::getLShr(RHS, ShAmt)); + + // If the shift is NSW and we compare to 0, then it is just shifting out + // sign bits, no need for an AND either. + if (cast(LHSI)->hasNoSignedWrap() && *RHSV == 0) + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantExpr::getLShr(RHS, ShAmt)); + + if (LHSI->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + Constant *Mask = + Builder->getInt(APInt::getLowBitsSet(TypeBits, TypeBits - ShAmtVal)); + + Value *And = Builder->CreateAnd(LHSI->getOperand(0), Mask, + LHSI->getName() + ".mask"); + return new ICmpInst(ICI.getPredicate(), And, + ConstantExpr::getLShr(RHS, ShAmt)); } + } - break; + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead. + ICmpInst::Predicate pred = ICI.getPredicate(); + if (isSignTest(pred, RHS) && cast(LHSI)->hasNoSignedWrap()) + return new ICmpInst(pred, LHSI->getOperand(0), + Constant::getNullValue(RHS->getType())); + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (LHSI->hasOneUse() && + isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { + // (X << 31) (X&1) != 0 + Constant *Mask = ConstantInt::get( + LHSI->getOperand(0)->getType(), + APInt::getOneBitSet(TypeBits, TypeBits - ShAmt->getZExtValue() - 1)); + Value *And = Builder->CreateAnd(LHSI->getOperand(0), Mask, + LHSI->getName() + ".mask"); + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); } - case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: { - // Handle equality comparisons of shift-by-constant. - BinaryOperator *BO = cast(LHSI); - if (ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1))) { - if (Instruction *Res = foldICmpShrConst(ICI, BO, ShAmt)) - return Res; - } + // Transform (icmp pred iM (shl iM %v, N), CI) + // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (CI>>N)) + // Transform the shl to a trunc if (trunc (CI>>N)) has no loss and M-N. + // This enables to get rid of the shift in favor of a trunc which can be + // free on the target. It has the additional benefit of comparing to a + // smaller constant, which will be target friendly. + unsigned Amt = ShAmt->getLimitedValue(TypeBits - 1); + if (LHSI->hasOneUse() && Amt != 0 && RHSV->countTrailingZeros() >= Amt) { + Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt); + Constant *NCI = ConstantExpr::getTrunc( + ConstantExpr::getAShr(RHS, ConstantInt::get(RHS->getType(), Amt)), NTy); + return new ICmpInst(ICI.getPredicate(), + Builder->CreateTrunc(LHSI->getOperand(0), NTy), NCI); + } + + return nullptr; +} - // Handle exact shr's. - if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { - if (RHSV->isMinValue()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); +Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + // Handle equality comparisons of shift-by-constant. + BinaryOperator *BO = cast(LHSI); + if (ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1))) { + if (Instruction *Res = foldICmpShrConstConst(ICI, BO, ShAmt)) + return Res; + } + + // Handle exact shr's. + if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { + if (RHSV->isMinValue()) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); + } + + return nullptr; +} + +Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &ICI, + Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + if (ConstantInt *DivLHS = dyn_cast(LHSI->getOperand(0))) { + Value *X = LHSI->getOperand(1); + const APInt &C1 = RHS->getValue(); + const APInt &C2 = DivLHS->getValue(); + assert(C2 != 0 && "udiv 0, X should have been simplified already."); + // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) + if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { + assert(!C1.isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, X, + ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); + } + // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) + if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { + assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, X, + ConstantInt::get(X->getType(), C2.udiv(C1))); } - break; } - case Instruction::UDiv: - if (ConstantInt *DivLHS = dyn_cast(LHSI->getOperand(0))) { - Value *X = LHSI->getOperand(1); - const APInt &C1 = RHS->getValue(); - const APInt &C2 = DivLHS->getValue(); - assert(C2 != 0 && "udiv 0, X should have been simplified already."); - // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C1.isMaxValue() && - "icmp ugt X, UINT_MAX should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_ULE, X, - ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); + return nullptr; +} + +Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + // Fold: icmp pred ([us]div X, C1), C2 -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. + if (ConstantInt *DivRHS = dyn_cast(LHSI->getOperand(1))) + if (Instruction *R = + foldICmpDivConstConst(ICI, cast(LHSI), DivRHS)) + return R; + + return nullptr; +} + +Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + ConstantInt *LHSC = dyn_cast(LHSI->getOperand(0)); + if (!LHSC) + return nullptr; + + const APInt &LHSV = LHSC->getValue(); + + // C1-X (X|(C2-1)) == C1 + // iff C1 & (C2-1) == C2-1 + // C2 is a power of 2 + if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && + RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == (*RHSV - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, + Builder->CreateOr(LHSI->getOperand(1), *RHSV - 1), + LHSC); + + // C1-X >u C2 -> (X|C2) != C1 + // iff C1 & C2 == C2 + // C2+1 is a power of 2 + if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && + (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == *RHSV) + return new ICmpInst(ICmpInst::ICMP_NE, + Builder->CreateOr(LHSI->getOperand(1), *RHSV), LHSC); + + return nullptr; +} + +Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV) { + // FIXME: This check restricts all folds under here to scalar types. + ConstantInt *RHS = dyn_cast(ICI.getOperand(1)); + if (!RHS) + return nullptr; + + // Fold: icmp pred (add X, C1), C2 + if (!ICI.isEquality()) { + ConstantInt *LHSC = dyn_cast(LHSI->getOperand(1)); + if (!LHSC) + return nullptr; + + const APInt &LHSV = LHSC->getValue(); + ConstantRange CR = + ICI.makeConstantRange(ICI.getPredicate(), *RHSV).subtract(LHSV); + + if (ICI.isSigned()) { + if (CR.getLower().isSignBit()) { + return new ICmpInst(ICmpInst::ICMP_SLT, LHSI->getOperand(0), + Builder->getInt(CR.getUpper())); + } else if (CR.getUpper().isSignBit()) { + return new ICmpInst(ICmpInst::ICMP_SGE, LHSI->getOperand(0), + Builder->getInt(CR.getLower())); } - // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) - if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { - assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), C2.udiv(C1))); + } else { + if (CR.getLower().isMinValue()) { + return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), + Builder->getInt(CR.getUpper())); + } else if (CR.getUpper().isMinValue()) { + return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), + Builder->getInt(CR.getLower())); } } - // fall-through - case Instruction::SDiv: - // Fold: icmp pred ([us]div X, C1), C2 -> range test - // Fold this div into the comparison, producing a range check. - // Determine, based on the divide type, what the range is being - // checked. If there is an overflow on the low or high side, remember - // it, otherwise compute the range [low, hi) bounding the new value. - // See: InsertRangeTest above for the kinds of replacements possible. - if (ConstantInt *DivRHS = dyn_cast(LHSI->getOperand(1))) - if (Instruction *R = foldICmpDivConst(ICI, cast(LHSI), - DivRHS)) - return R; - break; - - case Instruction::Sub: { - ConstantInt *LHSC = dyn_cast(LHSI->getOperand(0)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); - // C1-X (X|(C2-1)) == C1 - // iff C1 & (C2-1) == C2-1 + // X-C1 (X & -C2) == C1 + // iff C1 & (C2-1) == 0 // C2 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == (*RHSV - 1)) + RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == 0) return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateOr(LHSI->getOperand(1), *RHSV - 1), - LHSC); + Builder->CreateAnd(LHSI->getOperand(0), -(*RHSV)), + ConstantExpr::getNeg(LHSC)); - // C1-X >u C2 -> (X|C2) != C1 - // iff C1 & C2 == C2 + // X-C1 >u C2 -> (X & ~C2) != C1 + // iff C1 & C2 == 0 // C2+1 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == *RHSV) + (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == 0) return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateOr(LHSI->getOperand(1), *RHSV), LHSC); - break; + Builder->CreateAnd(LHSI->getOperand(0), ~(*RHSV)), + ConstantExpr::getNeg(LHSC)); } + return nullptr; +} - case Instruction::Add: - // Fold: icmp pred (add X, C1), C2 - if (!ICI.isEquality()) { - ConstantInt *LHSC = dyn_cast(LHSI->getOperand(1)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); - - ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), *RHSV) - .subtract(LHSV); - - if (ICI.isSigned()) { - if (CR.getLower().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SLT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } else { - if (CR.getLower().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C. +Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI) { + Instruction *LHSI; + const APInt *RHSV; + if (!match(ICI.getOperand(0), m_Instruction(LHSI)) || + !match(ICI.getOperand(1), m_APInt(RHSV))) + return nullptr; - // X-C1 (X & -C2) == C1 - // iff C1 & (C2-1) == 0 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateAnd(LHSI->getOperand(0), -(*RHSV)), - ConstantExpr::getNeg(LHSC)); - - // X-C1 >u C2 -> (X & ~C2) != C1 - // iff C1 & C2 == 0 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateAnd(LHSI->getOperand(0), ~(*RHSV)), - ConstantExpr::getNeg(LHSC)); - } + switch (LHSI->getOpcode()) { + case Instruction::Trunc: + if (Instruction *I = foldICmpTruncConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::Xor: + if (Instruction *I = foldICmpXorConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::And: + if (Instruction *I = foldICmpAndConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::Or: + if (Instruction *I = foldICmpOrConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::Mul: + if (Instruction *I = foldICmpMulConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::Shl: + if (Instruction *I = foldICmpShlConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::LShr: + case Instruction::AShr: + if (Instruction *I = foldICmpShrConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::UDiv: + if (Instruction *I = foldICmpUDivConstant(ICI, LHSI, RHSV)) + return I; + // fall-through + case Instruction::SDiv: + if (Instruction *I = foldICmpDivConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::Sub: + if (Instruction *I = foldICmpSubConstant(ICI, LHSI, RHSV)) + return I; + break; + case Instruction::Add: + if (Instruction *I = foldICmpAddConstant(ICI, LHSI, RHSV)) + return I; break; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 88c52e2..7ce5ac1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -548,10 +548,10 @@ private: ConstantInt *AndCst = nullptr); Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC); - Instruction *foldICmpDivConst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS); - Instruction *foldICmpShrConst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS); + Instruction *foldICmpDivConstConst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS); + Instruction *foldICmpShrConstConst(ICmpInst &ICI, BinaryOperator *DivI, + ConstantInt *DivRHS); Instruction *foldICmpCstShrConst(ICmpInst &I, Value *Op, Value *A, ConstantInt *CI1, ConstantInt *CI2); Instruction *foldICmpCstShlConst(ICmpInst &I, Value *Op, Value *A, @@ -560,6 +560,30 @@ private: ICmpInst::Predicate Pred); Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); Instruction *foldICmpWithConstant(ICmpInst &ICI); + + Instruction *foldICmpTruncConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpAndConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpXorConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpOrConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpMulConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpShlConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpShrConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpUDivConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpDivConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpSubConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpAddConstant(ICmpInst &ICI, Instruction *LHSI, + const APInt *RHSV); + Instruction *foldICmpEqualityWithConstant(ICmpInst &ICI); Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI); -- 2.7.4