[InstCombine] recode icmp fold in a vector-friendly way; NFC
authorSanjay Patel <spatel@rotateright.com>
Sun, 4 Sep 2016 14:32:15 +0000 (14:32 +0000)
committerSanjay Patel <spatel@rotateright.com>
Sun, 4 Sep 2016 14:32:15 +0000 (14:32 +0000)
The transform in question:
icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1'

...is still not enabled for vectors, thus no functional change intended.
It's not clear to me if this is a good transform for vectors or even
scalars in general. Changing that behavior may be a follow-on patch.

llvm-svn: 280627

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

index d0675b77444b2ac290f20cfeff88c96ff4e85d18..ad2bd1841f1f1dc6963508616ac47dd4268cab84 100644 (file)
@@ -1498,39 +1498,47 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
 Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
                                                  BinaryOperator *And,
                                                  const APInt *C1) {
-  // FIXME: This check restricts all folds under here to scalar types.
-  ConstantInt *RHS = dyn_cast<ConstantInt>(Cmp.getOperand(1));
-  if (!RHS)
-    return nullptr;
-
-  // FIXME: Use m_APInt.
-  auto *C2 = dyn_cast<ConstantInt>(And->getOperand(1));
-  if (!C2)
+  const APInt *C2;
+  if (!match(And->getOperand(1), m_APInt(C2)))
     return nullptr;
 
   if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse())
     return nullptr;
 
-  // If the LHS is an AND of a truncating cast, we can widen the and/compare to
-  // be the input width without changing the value produced, eliminating a cast.
-  if (TruncInst *Cast = dyn_cast<TruncInst>(And->getOperand(0))) {
-    // We can do this transformation if either the AND constant does not have
-    // its sign bit set or if it is an equality comparison. Extending a
-    // relational comparison when we're checking the sign bit would not work.
-    if (Cmp.isEquality() || (!C2->isNegative() && C1->isNonNegative())) {
-      Value *NewAnd = Builder->CreateAnd(
-          Cast->getOperand(0), ConstantExpr::getZExt(C2, Cast->getSrcTy()));
-      NewAnd->takeName(And);
-      return new ICmpInst(Cmp.getPredicate(), NewAnd,
-                          ConstantExpr::getZExt(RHS, Cast->getSrcTy()));
+  // If the LHS is an 'and' of a truncate and we can widen the and/compare to
+  // the input width without changing the value produced, eliminate the cast:
+  //
+  // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1'
+  //
+  // We can do this transformation if the constants do not have their sign bits
+  // set or if it is an equality comparison. Extending a relational comparison
+  // when we're checking the sign bit would not work.
+  Value *W;
+  if (match(And->getOperand(0), m_Trunc(m_Value(W))) &&
+      (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) {
+    // TODO: Is this a good transform for vectors? Wider types may reduce
+    // throughput. Should this transform be limited (even for scalars) by using
+    // ShouldChangeType()?
+    if (!Cmp.getType()->isVectorTy()) {
+      Type *WideType = W->getType();
+      unsigned WideScalarBits = WideType->getScalarSizeInBits();
+      Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits));
+      Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits));
+      Value *NewAnd = Builder->CreateAnd(W, ZextC2, And->getName());
+      return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1);
     }
   }
 
   if (Instruction *I = foldICmpAndShift(Cmp, And, C1))
     return I;
 
+  // FIXME: This check restricts all folds under here to scalar types.
+  ConstantInt *RHS = dyn_cast<ConstantInt>(Cmp.getOperand(1));
+  if (!RHS)
+    return nullptr;
+
   // (icmp pred (and (or (lshr A, B), A), 1), 0) -->
-  //    (icmp pred (and A, (or (shl 1, B), 1), 0))
+  // (icmp pred (and A, (or (shl 1, B), 1), 0))
   //
   // iff pred isn't signed
   {
@@ -1573,7 +1581,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
   // Replace ((X & C2) > C1) with ((X & C2) != 0), if any bit set in (X & C2)
   // will produce a result greater than C1.
   if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) {
-    unsigned NTZ = C2->getValue().countTrailingZeros();
+    unsigned NTZ = C2->countTrailingZeros();
     if ((NTZ < C2->getBitWidth()) &&
         APInt::getOneBitSet(C2->getBitWidth(), NTZ).ugt(*C1))
       return new ICmpInst(ICmpInst::ICMP_NE, And,