[InstSimplify] Extend simplifications for `(icmp ({z|s}ext X), C)` where `C` is vector
authorNoah Goldstein <goldstein.w.n@gmail.com>
Mon, 3 Apr 2023 04:40:08 +0000 (23:40 -0500)
committerNoah Goldstein <goldstein.w.n@gmail.com>
Mon, 3 Apr 2023 16:04:57 +0000 (11:04 -0500)
Previous logic only applied for `ConstantInt` which misses all vector
cases. New code works for splat/non-splat vectors as well. No change
to the underlying simplifications.

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D147275

llvm/lib/Analysis/InstructionSimplify.cpp
llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll

index eaf0af9..b82b0e7 100644 (file)
@@ -3818,22 +3818,27 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       }
       // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended
       // too.  If not, then try to deduce the result of the comparison.
-      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+      else if (match(RHS, m_ImmConstant())) {
+        Constant *C = dyn_cast<Constant>(RHS);
+        assert(C != nullptr);
+
         // Compute the constant that would happen if we truncated to SrcTy then
         // reextended to DstTy.
-        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy);
         Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
+        Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C);
 
-        // If the re-extended constant didn't change then this is effectively
-        // also a case of comparing two zero-extended values.
-        if (RExt == CI && MaxRecurse)
+        // If the re-extended constant didn't change any of the elements then
+        // this is effectively also a case of comparing two zero-extended
+        // values.
+        if (AnyEq->isAllOnesValue() && MaxRecurse)
           if (Value *V = simplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
                                           SrcOp, Trunc, Q, MaxRecurse - 1))
             return V;
 
         // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
         // there.  Use this to work out the result of the comparison.
-        if (RExt != CI) {
+        if (AnyEq->isNullValue()) {
           switch (Pred) {
           default:
             llvm_unreachable("Unknown ICmp predicate!");
@@ -3841,26 +3846,23 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
           case ICmpInst::ICMP_EQ:
           case ICmpInst::ICMP_UGT:
           case ICmpInst::ICMP_UGE:
-            return ConstantInt::getFalse(CI->getContext());
+            return Constant::getNullValue(ITy);
 
           case ICmpInst::ICMP_NE:
           case ICmpInst::ICMP_ULT:
           case ICmpInst::ICMP_ULE:
-            return ConstantInt::getTrue(CI->getContext());
+            return Constant::getAllOnesValue(ITy);
 
           // LHS is non-negative.  If RHS is negative then LHS >s LHS.  If RHS
           // is non-negative then LHS <s RHS.
           case ICmpInst::ICMP_SGT:
           case ICmpInst::ICMP_SGE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getTrue(CI->getContext())
-                       : ConstantInt::getFalse(CI->getContext());
-
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C,
+                                         Constant::getNullValue(C->getType()));
           case ICmpInst::ICMP_SLT:
           case ICmpInst::ICMP_SLE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getFalse(CI->getContext())
-                       : ConstantInt::getTrue(CI->getContext());
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C,
+                                         Constant::getNullValue(C->getType()));
           }
         }
       }
@@ -3887,42 +3889,44 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       }
       // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
       // too.  If not, then try to deduce the result of the comparison.
-      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+      else if (match(RHS, m_ImmConstant())) {
+        Constant *C = dyn_cast<Constant>(RHS);
+        assert(C != nullptr);
+
         // Compute the constant that would happen if we truncated to SrcTy then
         // reextended to DstTy.
-        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy);
         Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
+        Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C);
 
         // If the re-extended constant didn't change then this is effectively
         // also a case of comparing two sign-extended values.
-        if (RExt == CI && MaxRecurse)
+        if (AnyEq->isAllOnesValue() && MaxRecurse)
           if (Value *V =
                   simplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse - 1))
             return V;
 
         // Otherwise the upper bits of LHS are all equal, while RHS has varying
         // bits there.  Use this to work out the result of the comparison.
-        if (RExt != CI) {
+        if (AnyEq->isNullValue()) {
           switch (Pred) {
           default:
             llvm_unreachable("Unknown ICmp predicate!");
           case ICmpInst::ICMP_EQ:
-            return ConstantInt::getFalse(CI->getContext());
+            return Constant::getNullValue(ITy);
           case ICmpInst::ICMP_NE:
-            return ConstantInt::getTrue(CI->getContext());
+            return Constant::getAllOnesValue(ITy);
 
           // If RHS is non-negative then LHS <s RHS.  If RHS is negative then
           // LHS >s RHS.
           case ICmpInst::ICMP_SGT:
           case ICmpInst::ICMP_SGE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getTrue(CI->getContext())
-                       : ConstantInt::getFalse(CI->getContext());
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C,
+                                         Constant::getNullValue(C->getType()));
           case ICmpInst::ICMP_SLT:
           case ICmpInst::ICMP_SLE:
-            return CI->getValue().isNegative()
-                       ? ConstantInt::getFalse(CI->getContext())
-                       : ConstantInt::getTrue(CI->getContext());
+            return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C,
+                                         Constant::getNullValue(C->getType()));
 
           // If LHS is non-negative then LHS <u RHS.  If LHS is negative then
           // LHS >u RHS.
index d3240d6..4acf2fb 100644 (file)
@@ -3,9 +3,7 @@
 
 define <2 x i1> @icmp_eq_zext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_eq_zext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[XEXT]], <i32 511, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp eq <2 x i32> %xext, <i32 511, i32 1234>
@@ -14,9 +12,7 @@ define <2 x i1> @icmp_eq_zext_is_false(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ugt_zext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ugt_zext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ugt <2 x i32> %xext, <i32 256, i32 1234>
@@ -36,9 +32,7 @@ define <2 x i1> @icmp_ugt_zext_todo_off_by1(<2 x i8> %x) {
 
 define <2 x i1> @icmp_uge_zext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_uge_zext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp uge <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp uge <2 x i32> %xext, <i32 256, i32 1234>
@@ -69,9 +63,7 @@ define <2 x i1> @icmp_eq_zext_unused(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ne_zext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ne_zext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ne <2 x i32> %xext, <i32 256, i32 1234>
@@ -80,9 +72,7 @@ define <2 x i1> @icmp_ne_zext_is_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ult_zext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ult_zext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult <2 x i32> [[XEXT]], <i32 256, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ult <2 x i32> %xext, <i32 256, i32 1234>
@@ -91,9 +81,7 @@ define <2 x i1> @icmp_ult_zext_is_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ule_zext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ule_zext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule <2 x i32> [[XEXT]], <i32 256, i32 -1>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp ule <2 x i32> %xext, <i32 256, i32 -1>
@@ -124,9 +112,7 @@ define <2 x i1> @icmp_ne_zext_unused(<2 x i8> %x) {
 
 define <2 x i1> @icmp_sge_zext_is_false_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_sge_zext_is_false_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sge <2 x i32> [[XEXT]], <i32 257, i32 -450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 false, i1 true>
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp sge <2 x i32> %xext, <i32 257, i32 -450>
@@ -135,9 +121,7 @@ define <2 x i1> @icmp_sge_zext_is_false_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_sle_zext_is_false_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_sle_zext_is_false_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sle <2 x i32> [[XEXT]], <i32 -256, i32 -450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = zext <2 x i8> %x to <2 x i32>
   %cmp = icmp sle <2 x i32> %xext, <i32 -256, i32 -450>
@@ -146,9 +130,7 @@ define <2 x i1> @icmp_sle_zext_is_false_false(<2 x i8> %x) {
 
 define <2 x i1> @icmp_eq_sext_is_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_eq_sext_is_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[XEXT]], <i32 255, i32 129>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp eq <2 x i32> %xext, <i32 255, i32 129>
@@ -168,9 +150,7 @@ define <2 x i1> @icmp_eq_sext_fail(<2 x i8> %x) {
 
 define <2 x i1> @icmp_ne_sext_is_true(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_ne_sext_is_true(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i32> [[XEXT]], <i32 199, i32 1234>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 true>
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp ne <2 x i32> %xext, <i32 199, i32 1234>
@@ -179,9 +159,7 @@ define <2 x i1> @icmp_ne_sext_is_true(<2 x i8> %x) {
 
 define <2 x i1> @icmp_sgt_sext_is_true_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_sgt_sext_is_true_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <2 x i32> [[XEXT]], <i32 -250, i32 450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 false>
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp sgt <2 x i32> %xext, <i32 -250, i32 450>
@@ -190,9 +168,7 @@ define <2 x i1> @icmp_sgt_sext_is_true_false(<2 x i8> %x) {
 
 define <2 x i1> @icmp_slt_sext_is_true_false(<2 x i8> %x) {
 ; CHECK-LABEL: @icmp_slt_sext_is_true_false(
-; CHECK-NEXT:    [[XEXT:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i32> [[XEXT]], <i32 257, i32 -450>
-; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+; CHECK-NEXT:    ret <2 x i1> <i1 true, i1 false>
 ;
   %xext = sext <2 x i8> %x to <2 x i32>
   %cmp = icmp slt <2 x i32> %xext, <i32 257, i32 -450>