From 98cd99dfc671b78ff6d840d474ac9575ad3ea68b Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 18 Aug 2016 21:28:30 +0000 Subject: [PATCH] [InstCombine] add helper function for folds of icmp (shl 1, Y), C; NFCI Clean up the existing code by: 1. Renaming variables 2. Adding local variables 3. Making it vector-safe This is still guarded by a ConstantInt check, so no functional change is intended. But this should be ready to go: if we move the ConstantInt check down, all of these folds should do the right thing for vector types. llvm-svn: 279150 --- .../Transforms/InstCombine/InstCombineCompares.cpp | 127 +++++++++++---------- 1 file changed, 65 insertions(+), 62 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 83064aa..8ad5317 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1924,6 +1924,68 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, Instruction *Mul, return nullptr; } +/// Fold icmp (shl 1, Y), C. +static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, + const APInt *C) { + Value *Y; + if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) + return nullptr; + + Type *ShiftType = Shl->getType(); + uint32_t TypeBits = C->getBitWidth(); + bool CIsPowerOf2 = C->isPowerOf2(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isUnsigned()) { + // (1 << Y) pred C -> Y pred Log2(C) + if (!CIsPowerOf2) { + // (1 << Y) < 30 -> Y <= 4 + // (1 << Y) <= 30 -> Y <= 4 + // (1 << Y) >= 30 -> Y > 4 + // (1 << Y) > 30 -> Y > 4 + if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_ULE; + else if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_UGT; + } + + // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 + // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 + unsigned CLog2 = C->logBase2(); + if (CLog2 == TypeBits - 1) { + if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_EQ; + else if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_NE; + } + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); + } else if (Cmp.isSigned()) { + Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); + if (C->isAllOnesValue()) { + // (1 << Y) <= -1 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) > -1 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } else if (!(*C)) { + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) <= 0 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) >= 0 -> Y != 31 + // (1 << Y) > 0 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } + } else if (Cmp.isEquality() && CIsPowerOf2) { + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + } + + return nullptr; +} + Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &ICI, Instruction *LHSI, const APInt *RHSV) { // FIXME: This check restricts all folds under here to scalar types. @@ -1931,73 +1993,14 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &ICI, Instruction *LHSI, if (!RHS) return nullptr; - uint32_t TypeBits = RHSV->getBitWidth(); ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); - if (!ShAmt) { - Value *X; - // (1 << X) pred P2 -> X pred Log2(P2) - if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { - bool RHSVIsPowerOf2 = RHSV->isPowerOf2(); - ICmpInst::Predicate Pred = ICI.getPredicate(); - if (ICI.isUnsigned()) { - if (!RHSVIsPowerOf2) { - // (1 << X) < 30 -> X <= 4 - // (1 << X) <= 30 -> X <= 4 - // (1 << X) >= 30 -> X > 4 - // (1 << X) > 30 -> X > 4 - if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_ULE; - else if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_UGT; - } - unsigned RHSLog2 = RHSV->logBase2(); - - // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) < 2147483648 -> X < 31 -> X != 31 - if (RHSLog2 == TypeBits - 1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } - - return new ICmpInst(Pred, X, ConstantInt::get(RHS->getType(), RHSLog2)); - } else if (ICI.isSigned()) { - if (RHSV->isAllOnesValue()) { - // (1 << X) <= -1 -> X == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits - 1)); - - // (1 << X) > -1 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits - 1)); - } else if (!(*RHSV)) { - // (1 << X) < 0 -> X == 31 - // (1 << X) <= 0 -> X == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits - 1)); - - // (1 << X) >= 0 -> X != 31 - // (1 << X) > 0 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits - 1)); - } - } else if (ICI.isEquality()) { - if (RHSVIsPowerOf2) - return new ICmpInst( - Pred, X, ConstantInt::get(RHS->getType(), RHSV->logBase2())); - } - } - return nullptr; - } + if (!ShAmt) + return foldICmpShlOne(ICI, LHSI, RHSV); // Check that the shift amount is in range. If not, don't perform // undefined shifts. When the shift is visited it will be // simplified. + unsigned TypeBits = RHSV->getBitWidth(); if (ShAmt->uge(TypeBits)) return nullptr; -- 2.7.4