[InstCombine] use m_APInt in foldICmpWithConstant; NFCI
authorSanjay Patel <spatel@rotateright.com>
Tue, 16 Aug 2016 16:08:11 +0000 (16:08 +0000)
committerSanjay Patel <spatel@rotateright.com>
Tue, 16 Aug 2016 16:08:11 +0000 (16:08 +0000)
There's some formatting and pointer deref ugliness here that I intend to fix in
subsequent patches. The overall goal is to refactor the obnoxiously long switch
and incrementally remove the restriction to scalar types (allow folds for vector
splats). This patch introduces the use of m_APInt which means the RHSV reference
is now a pointer (and may have matched a vector splat), but the check of 'RHS'
remains, so vector folds are disallowed and no functional change is intended.

llvm-svn: 278816

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h

index 847b4a1..95dab35 100644 (file)
@@ -1530,15 +1530,22 @@ Instruction *InstCombiner::foldICmpCstShlConst(ICmpInst &I, Value *Op, Value *A,
   return getConstant(false);
 }
 
-/// Handle "icmp (instr, intcst)".
-Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
-                                                Instruction *LHSI,
-                                                ConstantInt *RHS) {
-  const APInt &RHSV = RHS->getValue();
+/// Try to fold integer comparisons with a constant operand: icmp Pred X, C.
+Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI) {
+  Instruction *LHSI;
+  const APInt *RHSV;
+  if (!match(ICI.getOperand(0), m_Instruction(LHSI)) ||
+      !match(ICI.getOperand(1), m_APInt(RHSV)))
+    return nullptr;
+
+  // FIXME: This check restricts all folds under here to scalar types.
+  ConstantInt *RHS = dyn_cast<ConstantInt>(ICI.getOperand(1));
+  if (!RHS)
+    return nullptr;
 
   switch (LHSI->getOpcode()) {
   case Instruction::Trunc:
-    if (RHS->isOne() && RHSV.getBitWidth() > 1) {
+    if (RHS->isOne() && RHSV->getBitWidth() > 1) {
       // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
       Value *V = nullptr;
       if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
@@ -1569,8 +1576,8 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
     if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) {
       // If this is a comparison that tests the signbit (X < 0) or (x > -1),
       // fold the xor.
-      if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) ||
-          (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) {
+      if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && *RHSV == 0) ||
+          (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV->isAllOnesValue())) {
         Value *CompareVal = LHSI->getOperand(0);
 
         // If the sign bit of the XorCst is not set, there is no change to
@@ -1603,7 +1610,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
                                          ? ICI.getUnsignedPredicate()
                                          : ICI.getSignedPredicate();
           return new ICmpInst(Pred, LHSI->getOperand(0),
-                              Builder->getInt(RHSV ^ SignBit));
+                              Builder->getInt(*RHSV ^ SignBit));
         }
 
         // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A)
@@ -1614,20 +1621,20 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
                                          : ICI.getSignedPredicate();
           Pred = ICI.getSwappedPredicate(Pred);
           return new ICmpInst(Pred, LHSI->getOperand(0),
-                              Builder->getInt(RHSV ^ NotSignBit));
+                              Builder->getInt(*RHSV ^ NotSignBit));
         }
       }
 
       // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C)
       //   iff -C is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT &&
-          XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2())
+          XorCst->getValue() == ~(*RHSV) && (*RHSV + 1).isPowerOf2())
         return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst);
 
       // (icmp ult (xor X, C), -C) -> (icmp uge X, C)
       //   iff -C is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_ULT &&
-          XorCst->getValue() == -RHSV && RHSV.isPowerOf2())
+          XorCst->getValue() == -(*RHSV) && RHSV->isPowerOf2())
         return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst);
     }
     break;
@@ -1645,7 +1652,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
         // Extending a relational comparison when we're checking the sign
         // bit would not work.
         if (ICI.isEquality() ||
-            (!AndCst->isNegative() && RHSV.isNonNegative())) {
+            (!AndCst->isNegative() && RHSV->isNonNegative())) {
           Value *NewAnd =
             Builder->CreateAnd(Cast->getOperand(0),
                                ConstantExpr::getZExt(AndCst, Cast->getSrcTy()));
@@ -1661,7 +1668,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
         IntegerType *Ty = cast<IntegerType>(Cast->getSrcTy());
         // Make sure we don't compare the upper bits, SimplifyDemandedBits
         // should fold the icmp to true/false in that case.
-        if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) {
+        if (ICI.isEquality() && RHSV->getActiveBits() <= Ty->getBitWidth()) {
           Value *NewAnd =
             Builder->CreateAnd(Cast->getOperand(0),
                                ConstantExpr::getTrunc(AndCst, Ty));
@@ -1754,7 +1761,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
       // Turn ((X >> Y) & C) == 0  into  (X & (C << Y)) == 0.  The later is
       // preferable because it allows the C<<Y expression to be hoisted out
       // of a loop if Y is invariant and X is not.
-      if (Shift && Shift->hasOneUse() && RHSV == 0 &&
+      if (Shift && Shift->hasOneUse() && *RHSV == 0 &&
           ICI.isEquality() && !Shift->isArithmeticShift() &&
           !isa<Constant>(Shift->getOperand(0))) {
         // Compute C << Y.
@@ -1780,7 +1787,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
       // iff pred isn't signed
       {
         Value *X, *Y, *LShr;
-        if (!ICI.isSigned() && RHSV == 0) {
+        if (!ICI.isSigned() && *RHSV == 0) {
           if (match(LHSI->getOperand(1), m_One())) {
             Constant *One = cast<Constant>(LHSI->getOperand(1));
             Value *Or = LHSI->getOperand(0);
@@ -1821,7 +1828,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
         unsigned NTZ = AndCst->getValue().countTrailingZeros();
         if ((NTZ < AndCst->getBitWidth()) &&
-            APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV))
+            APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(*RHSV))
           return new ICmpInst(ICmpInst::ICMP_NE, LHSI,
                               Constant::getNullValue(RHS->getType()));
       }
@@ -1843,7 +1850,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
     // X & -C == -C -> X >  u ~C
     // X & -C != -C -> X <= u ~C
     //   iff C is a power of 2
-    if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2())
+    if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-(*RHSV)).isPowerOf2())
       return new ICmpInst(
           ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT
                                                   : ICmpInst::ICMP_ULE,
@@ -1915,13 +1922,13 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
   }
 
   case Instruction::Shl: {       // (icmp pred (shl X, ShAmt), CI)
-    uint32_t TypeBits = RHSV.getBitWidth();
+    uint32_t TypeBits = RHSV->getBitWidth();
     ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
     if (!ShAmt) {
       Value *X;
       // (1 << X) pred P2 -> X pred Log2(P2)
       if (match(LHSI, m_Shl(m_One(), m_Value(X)))) {
-        bool RHSVIsPowerOf2 = RHSV.isPowerOf2();
+        bool RHSVIsPowerOf2 = RHSV->isPowerOf2();
         ICmpInst::Predicate Pred = ICI.getPredicate();
         if (ICI.isUnsigned()) {
           if (!RHSVIsPowerOf2) {
@@ -1934,7 +1941,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
             else if (Pred == ICmpInst::ICMP_UGE)
               Pred = ICmpInst::ICMP_UGT;
           }
-          unsigned RHSLog2 = RHSV.logBase2();
+          unsigned RHSLog2 = RHSV->logBase2();
 
           // (1 << X) >= 2147483648 -> X >= 31 -> X == 31
           // (1 << X) <  2147483648 -> X <  31 -> X != 31
@@ -1948,7 +1955,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
           return new ICmpInst(Pred, X,
                               ConstantInt::get(RHS->getType(), RHSLog2));
         } else if (ICI.isSigned()) {
-          if (RHSV.isAllOnesValue()) {
+          if (RHSV->isAllOnesValue()) {
             // (1 << X) <= -1 -> X == 31
             if (Pred == ICmpInst::ICMP_SLE)
               return new ICmpInst(ICmpInst::ICMP_EQ, X,
@@ -1958,7 +1965,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
             if (Pred == ICmpInst::ICMP_SGT)
               return new ICmpInst(ICmpInst::ICMP_NE, X,
                                   ConstantInt::get(RHS->getType(), TypeBits-1));
-          } else if (!RHSV) {
+          } else if (!(*RHSV)) {
             // (1 << X) <  0 -> X == 31
             // (1 << X) <= 0 -> X == 31
             if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
@@ -1974,7 +1981,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
         } else if (ICI.isEquality()) {
           if (RHSVIsPowerOf2)
             return new ICmpInst(
-                Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
+                Pred, X, ConstantInt::get(RHS->getType(), RHSV->logBase2()));
         }
       }
       break;
@@ -2006,7 +2013,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
 
       // If the shift is NSW and we compare to 0, then it is just shifting out
       // sign bits, no need for an AND either.
-      if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0)
+      if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && *RHSV == 0)
         return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
                             ConstantExpr::getLShr(RHS, ShAmt));
 
@@ -2054,7 +2061,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
     // smaller constant, which will be target friendly.
     unsigned Amt = ShAmt->getLimitedValue(TypeBits-1);
     if (LHSI->hasOneUse() &&
-        Amt != 0 && RHSV.countTrailingZeros() >= Amt) {
+        Amt != 0 && RHSV->countTrailingZeros() >= Amt) {
       Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt);
       Constant *NCI = ConstantExpr::getTrunc(
                         ConstantExpr::getAShr(RHS,
@@ -2079,7 +2086,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
 
     // Handle exact shr's.
     if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) {
-      if (RHSV.isMinValue())
+      if (RHSV->isMinValue())
         return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS);
     }
     break;
@@ -2128,18 +2135,18 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
     //   iff C1 & (C2-1) == C2-1
     //       C2 is a power of 2
     if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() &&
-        RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1))
+        RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == (*RHSV - 1))
       return new ICmpInst(ICmpInst::ICMP_EQ,
-                          Builder->CreateOr(LHSI->getOperand(1), RHSV - 1),
+                          Builder->CreateOr(LHSI->getOperand(1), *RHSV - 1),
                           LHSC);
 
     // C1-X >u C2 -> (X|C2) != C1
     //   iff C1 & C2 == C2
     //       C2+1 is a power of 2
     if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() &&
-        (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV)
+        (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == *RHSV)
       return new ICmpInst(ICmpInst::ICMP_NE,
-                          Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC);
+                          Builder->CreateOr(LHSI->getOperand(1), *RHSV), LHSC);
     break;
   }
 
@@ -2150,7 +2157,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
       if (!LHSC) break;
       const APInt &LHSV = LHSC->getValue();
 
-      ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV)
+      ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), *RHSV)
                             .subtract(LHSV);
 
       if (ICI.isSigned()) {
@@ -2175,18 +2182,18 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI,
       //   iff C1 & (C2-1) == 0
       //       C2 is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() &&
-          RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0)
+          RHSV->isPowerOf2() && (LHSV & (*RHSV - 1)) == 0)
         return new ICmpInst(ICmpInst::ICMP_EQ,
-                            Builder->CreateAnd(LHSI->getOperand(0), -RHSV),
+                            Builder->CreateAnd(LHSI->getOperand(0), -(*RHSV)),
                             ConstantExpr::getNeg(LHSC));
 
       // X-C1 >u C2 -> (X & ~C2) != C1
       //   iff C1 & C2 == 0
       //       C2+1 is a power of 2
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() &&
-          (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0)
+          (*RHSV + 1).isPowerOf2() && (LHSV & *RHSV) == 0)
         return new ICmpInst(ICmpInst::ICMP_NE,
-                            Builder->CreateAnd(LHSI->getOperand(0), ~RHSV),
+                            Builder->CreateAnd(LHSI->getOperand(0), ~(*RHSV)),
                             ConstantExpr::getNeg(LHSC));
     }
     break;
@@ -3627,17 +3634,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   // See if we are doing a comparison between a constant and an instruction that
   // can be folded into the comparison.
 
-  // FIXME: Use m_APInt instead of dyn_cast<ConstantInt> to allow these
-  // transforms for vectors.
-
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-    // Since the RHS is a ConstantInt (CI), if the left hand side is an
-    // instruction, see if that instruction also has constants so that the
-    // instruction can be folded into the icmp
-    if (Instruction *LHSI = dyn_cast<Instruction>(Op0))
-      if (Instruction *Res = foldICmpWithConstant(I, LHSI, CI))
-        return Res;
-  }
+  if (Instruction *Res = foldICmpWithConstant(I))
+    return Res;
 
   if (Instruction *Res = foldICmpEqualityWithConstant(I))
     return Res;
index c555ff8..88c52e2 100644 (file)
@@ -559,8 +559,7 @@ private:
   Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI,
                                   ICmpInst::Predicate Pred);
   Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
-  Instruction *foldICmpWithConstant(ICmpInst &ICI, Instruction *LHS,
-                                    ConstantInt *RHS);
+  Instruction *foldICmpWithConstant(ICmpInst &ICI);
   Instruction *foldICmpEqualityWithConstant(ICmpInst &ICI);
   Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI);