[InstCombine] change param type from Instruction to BinaryOperator for icmp helpers...
authorSanjay Patel <spatel@rotateright.com>
Mon, 22 Aug 2016 21:24:29 +0000 (21:24 +0000)
committerSanjay Patel <spatel@rotateright.com>
Mon, 22 Aug 2016 21:24:29 +0000 (21:24 +0000)
This saves some casting in the helper functions and eases some further refactoring.

llvm-svn: 279478

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h

index 68e55c4..e6879d5 100644 (file)
@@ -1569,7 +1569,8 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
 }
 
 /// Fold icmp (xor X, Y), C.
-Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, Instruction *Xor,
+Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
+                                               BinaryOperator *Xor,
                                                const APInt *C) {
   Value *X = Xor->getOperand(0);
   Value *Y = Xor->getOperand(1);
@@ -1634,7 +1635,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, Instruction *Xor,
   return nullptr;
 }
 
-Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &ICI, Instruction *LHSI,
+Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &ICI,
+                                               BinaryOperator *LHSI,
                                                const APInt *RHSV) {
   // FIXME: This check restricts all folds under here to scalar types.
   ConstantInt *RHS = dyn_cast<ConstantInt>(ICI.getOperand(1));
@@ -1875,7 +1877,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &ICI, Instruction *LHSI,
 }
 
 /// Fold icmp (or X, Y), C.
-Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, Instruction *Or,
+Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
                                               const APInt *C) {
   ICmpInst::Predicate Pred = Cmp.getPredicate();
   if (*C == 1) {
@@ -1906,7 +1908,8 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, Instruction *Or,
 }
 
 /// Fold icmp (mul X, Y), C.
-Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, Instruction *Mul,
+Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp,
+                                               BinaryOperator *Mul,
                                                const APInt *C) {
   const APInt *MulC;
   if (!match(Mul->getOperand(1), m_APInt(MulC)))
@@ -1915,7 +1918,7 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, Instruction *Mul,
   // If this is a test of the sign bit and the multiply is sign-preserving with
   // a constant operand, use the multiply LHS operand instead.
   ICmpInst::Predicate Pred = Cmp.getPredicate();
-  if (isSignTest(Pred, *C) && cast<BinaryOperator>(Mul)->hasNoSignedWrap()) {
+  if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) {
     if (MulC->isNegative())
       Pred = ICmpInst::getSwappedPredicate(Pred);
     return new ICmpInst(Pred, Mul->getOperand(0),
@@ -1988,7 +1991,8 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
 }
 
 /// Fold icmp (shl X, Y), C.
-Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, Instruction *Shl,
+Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
+                                               BinaryOperator *Shl,
                                                const APInt *C) {
   const APInt *ShiftAmt;
   if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
@@ -2006,12 +2010,12 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, Instruction *Shl,
     // If the shift is NUW, then it is just shifting out zeros, no need for an
     // AND.
     Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt));
-    if (cast<BinaryOperator>(Shl)->hasNoUnsignedWrap())
+    if (Shl->hasNoUnsignedWrap())
       return new ICmpInst(Pred, X, LShrC);
 
     // If the shift is NSW and we compare to 0, then it is just shifting out
     // sign bits, no need for an AND either.
-    if (cast<BinaryOperator>(Shl)->hasNoSignedWrap() && *C == 0)
+    if (Shl->hasNoSignedWrap() && *C == 0)
       return new ICmpInst(Pred, X, LShrC);
 
     if (Shl->hasOneUse()) {
@@ -2027,7 +2031,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, Instruction *Shl,
   // If this is a signed comparison to 0 and the shift is sign preserving,
   // use the shift LHS operand instead; isSignTest may change 'Pred', so only
   // do that if we're sure to not continue on in this function.
-  if (cast<BinaryOperator>(Shl)->hasNoSignedWrap() && isSignTest(Pred, *C))
+  if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C))
     return new ICmpInst(Pred, X, Constant::getNullValue(X->getType()));
 
   // Otherwise, if this is a comparison of the sign bit, simplify to and/test.
@@ -2062,22 +2066,22 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, Instruction *Shl,
 }
 
 /// Fold icmp ({al}shr X, Y), C.
-Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &ICI, Instruction *LHSI,
-                                               const APInt *RHSV) {
+Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
+                                               BinaryOperator *Shr,
+                                               const APInt *C) {
   // An exact shr only shifts out zero bits, so:
   // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0
-  CmpInst::Predicate Pred = ICI.getPredicate();
-  BinaryOperator *BO = cast<BinaryOperator>(LHSI);
-  if (ICI.isEquality() && BO->isExact() && BO->hasOneUse() && *RHSV == 0)
-    return new ICmpInst(Pred, BO->getOperand(0), ICI.getOperand(1));
+  CmpInst::Predicate Pred = Cmp.getPredicate();
+  if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && *C == 0)
+    return new ICmpInst(Pred, Shr->getOperand(0), Cmp.getOperand(1));
 
   // FIXME: This check restricts all folds under here to scalar types.
   // Handle equality comparisons of shift-by-constant.
-  ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
+  ConstantInt *ShAmt = dyn_cast<ConstantInt>(Shr->getOperand(1));
   if (!ShAmt)
     return nullptr;
 
-  if (Instruction *Res = foldICmpShrConstConst(ICI, BO, ShAmt))
+  if (Instruction *Res = foldICmpShrConstConst(Cmp, Shr, ShAmt))
     return Res;
 
   return nullptr;
@@ -2085,7 +2089,7 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &ICI, Instruction *LHSI,
 
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
-                                                Instruction *UDiv,
+                                                BinaryOperator *UDiv,
                                                 const APInt *C) {
   const APInt *C2;
   if (!match(UDiv->getOperand(0), m_APInt(C2)))
@@ -2112,7 +2116,8 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
-Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &ICI, Instruction *LHSI,
+Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &ICI,
+                                               BinaryOperator *LHSI,
                                                const APInt *RHSV) {
   // FIXME: This check restricts all folds under here to scalar types.
   ConstantInt *RHS = dyn_cast<ConstantInt>(ICI.getOperand(1));
@@ -2127,14 +2132,15 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &ICI, Instruction *LHSI,
   // See: InsertRangeTest above for the kinds of replacements possible.
   if (ConstantInt *DivRHS = dyn_cast<ConstantInt>(LHSI->getOperand(1)))
     if (Instruction *R =
-            foldICmpDivConstConst(ICI, cast<BinaryOperator>(LHSI), DivRHS))
+            foldICmpDivConstConst(ICI, LHSI, DivRHS))
       return R;
 
   return nullptr;
 }
 
 /// Fold icmp (sub X, Y), C.
-Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, Instruction *Sub,
+Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
+                                               BinaryOperator *Sub,
                                                const APInt *C) {
   const APInt *C2;
   if (!match(Sub->getOperand(0), m_APInt(C2)) || !Sub->hasOneUse())
@@ -2162,7 +2168,8 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, Instruction *Sub,
 }
 
 /// Fold icmp (add X, Y), C.
-Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, Instruction *Add,
+Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
+                                               BinaryOperator *Add,
                                                const APInt *C) {
   Value *Y = Add->getOperand(1);
   const APInt *C2;
@@ -2210,61 +2217,66 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, Instruction *Add,
 }
 
 /// Try to fold integer comparisons with a constant operand: icmp Pred X, C.
-Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &ICI) {
-  Instruction *LHSI;
-  const APInt *RHSV;
-  if (!match(ICI.getOperand(0), m_Instruction(LHSI)) ||
-      !match(ICI.getOperand(1), m_APInt(RHSV)))
+Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
+  const APInt *C;
+  if (!match(Cmp.getOperand(1), m_APInt(C)))
     return nullptr;
 
-  switch (LHSI->getOpcode()) {
-  case Instruction::Trunc:
-    if (Instruction *I = foldICmpTruncConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::Xor:
-    if (Instruction *I = foldICmpXorConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::And:
-    if (Instruction *I = foldICmpAndConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::Or:
-    if (Instruction *I = foldICmpOrConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::Mul:
-    if (Instruction *I = foldICmpMulConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::Shl:
-    if (Instruction *I = foldICmpShlConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::LShr:
-  case Instruction::AShr:
-    if (Instruction *I = foldICmpShrConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::UDiv:
-    if (Instruction *I = foldICmpUDivConstant(ICI, LHSI, RHSV))
-      return I;
-    LLVM_FALLTHROUGH;
-  case Instruction::SDiv:
-    if (Instruction *I = foldICmpDivConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::Sub:
-    if (Instruction *I = foldICmpSubConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
-  case Instruction::Add:
-    if (Instruction *I = foldICmpAddConstant(ICI, LHSI, RHSV))
-      return I;
-    break;
+  BinaryOperator *BO;
+  if (match(Cmp.getOperand(0), m_BinOp(BO))) {
+    switch (BO->getOpcode()) {
+    case Instruction::Xor:
+      if (Instruction *I = foldICmpXorConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::And:
+      if (Instruction *I = foldICmpAndConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::Or:
+      if (Instruction *I = foldICmpOrConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::Mul:
+      if (Instruction *I = foldICmpMulConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::Shl:
+      if (Instruction *I = foldICmpShlConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::LShr:
+    case Instruction::AShr:
+      if (Instruction *I = foldICmpShrConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::UDiv:
+      if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
+        return I;
+      LLVM_FALLTHROUGH;
+    case Instruction::SDiv:
+      if (Instruction *I = foldICmpDivConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::Sub:
+      if (Instruction *I = foldICmpSubConstant(Cmp, BO, C))
+        return I;
+      break;
+    case Instruction::Add:
+      if (Instruction *I = foldICmpAddConstant(Cmp, BO, C))
+        return I;
+      break;
+    default:
+      break;
+    }
   }
 
+  Instruction *LHSI;
+  if (match(Cmp.getOperand(0), m_Instruction(LHSI)) &&
+      LHSI->getOpcode() == Instruction::Trunc)
+    if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C))
+      return I;
+
   return nullptr;
 }
 
index 7ce5ac1..f3ee94a 100644 (file)
@@ -559,30 +559,30 @@ private:
   Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI,
                                   ICmpInst::Predicate Pred);
   Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
-  Instruction *foldICmpWithConstant(ICmpInst &ICI);
-
-  Instruction *foldICmpTruncConstant(ICmpInst &ICI, Instruction *LHSI,
-                                     const APInt *RHSV);
-  Instruction *foldICmpAndConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpXorConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpOrConstant(ICmpInst &ICI, Instruction *LHSI,
-                                  const APInt *RHSV);
-  Instruction *foldICmpMulConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpShlConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpShrConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpUDivConstant(ICmpInst &ICI, Instruction *LHSI,
-                                    const APInt *RHSV);
-  Instruction *foldICmpDivConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpSubConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
-  Instruction *foldICmpAddConstant(ICmpInst &ICI, Instruction *LHSI,
-                                   const APInt *RHSV);
+  Instruction *foldICmpWithConstant(ICmpInst &Cmp);
+
+  Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,
+                                     const APInt *C);
+  Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,
+                                   const APInt *C);
+  Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor,
+                                   const APInt *C);
+  Instruction *foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
+                                  const APInt *C);
+  Instruction *foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul,
+                                   const APInt *C);
+  Instruction *foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl,
+                                   const APInt *C);
+  Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
+                                   const APInt *C);
+  Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+                                    const APInt *C);
+  Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,
+                                   const APInt *C);
+  Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub,
+                                   const APInt *C);
+  Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add,
+                                   const APInt *C);
 
   Instruction *foldICmpEqualityWithConstant(ICmpInst &ICI);
   Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI);