[InstCombine] try to fold icmp with mismatched extended operands
authorSanjay Patel <spatel@rotateright.com>
Tue, 26 Apr 2022 18:22:16 +0000 (14:22 -0400)
committerSanjay Patel <spatel@rotateright.com>
Tue, 26 Apr 2022 18:26:36 +0000 (14:26 -0400)
If a value is known to be non-negative and zexted,
that's the same thing as sexted.

So for the purpose of looking past the casts with
an icmp, treat it as if it was a sext:
https://alive2.llvm.org/ce/z/_BDsGV

This is necessary, but not enough to solve the
motivating problem:
https://github.com/llvm/llvm-project/issues/55013

Differential Revision: https://reviews.llvm.org/D124419

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/InstCombine/InstCombineInternal.h
llvm/test/Transforms/InstCombine/icmp-ext-ext.ll

index 6c70299..4a62afc 100644 (file)
@@ -4706,8 +4706,7 @@ static Instruction *foldICmpWithTrunc(ICmpInst &ICmp,
   return nullptr;
 }
 
-static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp,
-                                           InstCombiner::BuilderTy &Builder) {
+Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
   assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
   auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0));
   Value *X;
@@ -4716,25 +4715,37 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp,
 
   bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt;
   bool IsSignedCmp = ICmp.isSigned();
-  if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) {
-    // If the signedness of the two casts doesn't agree (i.e. one is a sext
-    // and the other is a zext), then we can't handle this.
-    // TODO: This is too strict. We can handle some predicates (equality?).
-    if (CastOp0->getOpcode() != CastOp1->getOpcode())
-      return nullptr;
+
+  // icmp Pred (ext X), (ext Y)
+  Value *Y;
+  if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) {
+    bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0));
+    bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1));
+
+    // If we have mismatched casts, treat the zext of a non-negative source as
+    // a sext to simulate matching casts. Otherwise, we are done.
+    // TODO: Can we handle some predicates (equality) without non-negative?
+    if (IsZext0 != IsZext1) {
+      if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) ||
+          (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT)))
+        IsSignedExt = true;
+      else
+        return nullptr;
+    }
 
     // Not an extension from the same type?
-    Value *Y = CastOp1->getOperand(0);
     Type *XTy = X->getType(), *YTy = Y->getType();
     if (XTy != YTy) {
       // One of the casts must have one use because we are creating a new cast.
-      if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse())
+      if (!ICmp.getOperand(0)->hasOneUse() && !ICmp.getOperand(1)->hasOneUse())
         return nullptr;
       // Extend the narrower operand to the type of the wider operand.
+      CastInst::CastOps CastOpcode =
+          IsSignedExt ? Instruction::SExt : Instruction::ZExt;
       if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits())
-        X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy);
+        X = Builder.CreateCast(CastOpcode, X, YTy);
       else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits())
-        Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy);
+        Y = Builder.CreateCast(CastOpcode, Y, XTy);
       else
         return nullptr;
     }
@@ -4852,7 +4863,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
   if (Instruction *R = foldICmpWithTrunc(ICmp, Builder))
     return R;
 
-  return foldICmpWithZextOrSext(ICmp, Builder);
+  return foldICmpWithZextOrSext(ICmp);
 }
 
 static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
index 432cd8a..3e5fc3e 100644 (file)
@@ -661,7 +661,8 @@ public:
                                     Constant *RHSC);
   Instruction *foldICmpAddOpConst(Value *X, const APInt &C,
                                   ICmpInst::Predicate Pred);
-  Instruction *foldICmpWithCastOp(ICmpInst &ICI);
+  Instruction *foldICmpWithCastOp(ICmpInst &ICmp);
+  Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp);
 
   Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
   Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
index c0ff526..42472bb 100644 (file)
@@ -250,9 +250,7 @@ define i1 @sext_zext_uge_op0_wide(i16 %x, i8 %y) {
 define i1 @zext_sext_sgt_known_nonneg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @zext_sext_sgt_known_nonneg(
 ; CHECK-NEXT:    [[N:%.*]] = udiv i8 127, [[X:%.*]]
-; CHECK-NEXT:    [[A:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[B:%.*]] = sext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i8 [[N]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %n = udiv i8 127, %x
@@ -265,9 +263,7 @@ define i1 @zext_sext_sgt_known_nonneg(i8 %x, i8 %y) {
 define i1 @zext_sext_ugt_known_nonneg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @zext_sext_ugt_known_nonneg(
 ; CHECK-NEXT:    [[N:%.*]] = and i8 [[X:%.*]], 127
-; CHECK-NEXT:    [[A:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[B:%.*]] = sext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i8 [[N]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %n = and i8 %x, 127
@@ -280,9 +276,7 @@ define i1 @zext_sext_ugt_known_nonneg(i8 %x, i8 %y) {
 define i1 @zext_sext_eq_known_nonneg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @zext_sext_eq_known_nonneg(
 ; CHECK-NEXT:    [[N:%.*]] = lshr i8 [[X:%.*]], 1
-; CHECK-NEXT:    [[A:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[B:%.*]] = sext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i8 [[N]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %n = lshr i8 %x, 1
@@ -295,9 +289,8 @@ define i1 @zext_sext_eq_known_nonneg(i8 %x, i8 %y) {
 define i1 @zext_sext_sle_known_nonneg_op0_narrow(i8 %x, i16 %y) {
 ; CHECK-LABEL: @zext_sext_sle_known_nonneg_op0_narrow(
 ; CHECK-NEXT:    [[N:%.*]] = and i8 [[X:%.*]], 12
-; CHECK-NEXT:    [[A:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[B:%.*]] = sext i16 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp sle i32 [[A]], [[B]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 [[N]] to i16
+; CHECK-NEXT:    [[C:%.*]] = icmp sle i16 [[TMP1]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %n = and i8 %x, 12
@@ -310,9 +303,8 @@ define i1 @zext_sext_sle_known_nonneg_op0_narrow(i8 %x, i16 %y) {
 define i1 @zext_sext_ule_known_nonneg_op0_wide(i9 %x, i8 %y) {
 ; CHECK-LABEL: @zext_sext_ule_known_nonneg_op0_wide(
 ; CHECK-NEXT:    [[N:%.*]] = urem i9 [[X:%.*]], 254
-; CHECK-NEXT:    [[A:%.*]] = zext i9 [[N]] to i32
-; CHECK-NEXT:    [[B:%.*]] = sext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp ule i32 [[A]], [[B]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i8 [[Y:%.*]] to i9
+; CHECK-NEXT:    [[C:%.*]] = icmp ule i9 [[N]], [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %n = urem i9 %x, 254
@@ -324,10 +316,8 @@ define i1 @zext_sext_ule_known_nonneg_op0_wide(i9 %x, i8 %y) {
 
 define i1 @sext_zext_slt_known_nonneg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @sext_zext_slt_known_nonneg(
-; CHECK-NEXT:    [[A:%.*]] = sext i8 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[N:%.*]] = and i8 [[Y:%.*]], 126
-; CHECK-NEXT:    [[B:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[A]], [[B]]
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i8 [[N]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %a = sext i8 %x to i32
@@ -339,10 +329,8 @@ define i1 @sext_zext_slt_known_nonneg(i8 %x, i8 %y) {
 
 define i1 @sext_zext_ult_known_nonneg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @sext_zext_ult_known_nonneg(
-; CHECK-NEXT:    [[A:%.*]] = sext i8 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[N:%.*]] = lshr i8 [[Y:%.*]], 6
-; CHECK-NEXT:    [[B:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[A]], [[B]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i8 [[N]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %a = sext i8 %x to i32
@@ -354,10 +342,8 @@ define i1 @sext_zext_ult_known_nonneg(i8 %x, i8 %y) {
 
 define i1 @sext_zext_ne_known_nonneg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @sext_zext_ne_known_nonneg(
-; CHECK-NEXT:    [[A:%.*]] = sext i8 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[N:%.*]] = udiv i8 [[Y:%.*]], 6
-; CHECK-NEXT:    [[B:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp ne i32 [[A]], [[B]]
+; CHECK-NEXT:    [[C:%.*]] = icmp ne i8 [[N]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %a = sext i8 %x to i32
@@ -369,10 +355,9 @@ define i1 @sext_zext_ne_known_nonneg(i8 %x, i8 %y) {
 
 define <2 x i1> @sext_zext_sge_known_nonneg_op0_narrow(<2 x i5> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @sext_zext_sge_known_nonneg_op0_narrow(
-; CHECK-NEXT:    [[A:%.*]] = sext <2 x i5> [[X:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[N:%.*]] = mul nsw <2 x i8> [[Y:%.*]], [[Y]]
-; CHECK-NEXT:    [[B:%.*]] = zext <2 x i8> [[N]] to <2 x i32>
-; CHECK-NEXT:    [[C:%.*]] = icmp sge <2 x i32> [[A]], [[B]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <2 x i5> [[X:%.*]] to <2 x i8>
+; CHECK-NEXT:    [[C:%.*]] = icmp sle <2 x i8> [[N]], [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i1> [[C]]
 ;
   %a = sext <2 x i5> %x to <2 x i32>
@@ -384,10 +369,9 @@ define <2 x i1> @sext_zext_sge_known_nonneg_op0_narrow(<2 x i5> %x, <2 x i8> %y)
 
 define i1 @sext_zext_uge_known_nonneg_op0_wide(i16 %x, i8 %y) {
 ; CHECK-LABEL: @sext_zext_uge_known_nonneg_op0_wide(
-; CHECK-NEXT:    [[A:%.*]] = sext i16 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[N:%.*]] = and i8 [[Y:%.*]], 12
-; CHECK-NEXT:    [[B:%.*]] = zext i8 [[N]] to i32
-; CHECK-NEXT:    [[C:%.*]] = icmp uge i32 [[A]], [[B]]
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 [[N]] to i16
+; CHECK-NEXT:    [[C:%.*]] = icmp ule i16 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[C]]
 ;
   %a = sext i16 %x to i32