[InstCombine] fold icmp equality with udiv and large constant
authorSanjay Patel <spatel@rotateright.com>
Thu, 26 May 2022 12:23:36 +0000 (08:23 -0400)
committerSanjay Patel <spatel@rotateright.com>
Thu, 26 May 2022 13:08:47 +0000 (09:08 -0400)
With large compare constant:
(X u/ Y) == C --> (X == C) && (Y == 1)
(X u/ Y) != C --> (X != C) || (Y != 1)

https://alive2.llvm.org/ce/z/EhKwh6

There are various potential missing icmp (div) transforms shown here:
https://github.com/llvm/llvm-project/issues/55695

This is a generalization for part of the udiv + equality.
I didn't check in detail, but some of those may only make sense as
codegen transforms.

This results in one extra instruction in IR, but it is better for
analysis, and looks much better in codegen on all targets that I tried.

Differential Revision: https://reviews.llvm.org/D126410

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/icmp-div-constant.ll

index f885859..9f5f65c 100644 (file)
@@ -2383,26 +2383,41 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                     BinaryOperator *UDiv,
                                                     const APInt &C) {
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *X = UDiv->getOperand(0);
+  Value *Y = UDiv->getOperand(1);
+  Type *Ty = UDiv->getType();
+
+  // If the compare constant is bigger than UMAX/2 (negative), there's only one
+  // pair of values that satisfies an equality check, so eliminate the division:
+  // (X u/ Y) == C --> (X == C) && (Y == 1)
+  // (X u/ Y) != C --> (X != C) || (Y != 1)
+  if (Cmp.isEquality() && UDiv->hasOneUse() && C.isSignBitSet()) {
+    Value *XBig = Builder.CreateICmp(Pred, X, ConstantInt::get(Ty, C));
+    Value *YOne = Builder.CreateICmp(Pred, Y, ConstantInt::get(Ty, 1));
+    auto Logic = Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or;
+    return BinaryOperator::Create(Logic, XBig, YOne);
+  }
+
   const APInt *C2;
-  if (!match(UDiv->getOperand(0), m_APInt(C2)))
+  if (!match(X, m_APInt(C2)))
     return nullptr;
 
   assert(*C2 != 0 && "udiv 0, X should have been simplified already.");
 
   // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1))
-  Value *Y = UDiv->getOperand(1);
-  if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) {
+  if (Pred == ICmpInst::ICMP_UGT) {
     assert(!C.isMaxValue() &&
            "icmp ugt X, UINT_MAX should have been simplified already.");
     return new ICmpInst(ICmpInst::ICMP_ULE, Y,
-                        ConstantInt::get(Y->getType(), C2->udiv(C + 1)));
+                        ConstantInt::get(Ty, C2->udiv(C + 1)));
   }
 
   // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C)
-  if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) {
+  if (Pred == ICmpInst::ICMP_ULT) {
     assert(C != 0 && "icmp ult X, 0 should have been simplified already.");
     return new ICmpInst(ICmpInst::ICMP_UGT, Y,
-                        ConstantInt::get(Y->getType(), C2->udiv(C)));
+                        ConstantInt::get(Ty, C2->udiv(C)));
   }
 
   return nullptr;
index ff587dc..e9b26e6 100644 (file)
@@ -198,8 +198,9 @@ exit:
 
 define i1 @udiv_eq_umax(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_eq_umax(
-; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[D]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[X:%.*]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i8 [[Y:%.*]], 1
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %d = udiv i8 %x, %y
@@ -209,8 +210,9 @@ define i1 @udiv_eq_umax(i8 %x, i8 %y) {
 
 define <2 x i1> @udiv_ne_umax(<2 x i5> %x, <2 x i5> %y) {
 ; CHECK-LABEL: @udiv_ne_umax(
-; CHECK-NEXT:    [[D:%.*]] = udiv <2 x i5> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne <2 x i5> [[D]], <i5 -1, i5 -1>
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne <2 x i5> [[X:%.*]], <i5 -1, i5 -1>
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <2 x i5> [[Y:%.*]], <i5 1, i5 1>
+; CHECK-NEXT:    [[R:%.*]] = or <2 x i1> [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %d = udiv <2 x i5> %x, %y
@@ -220,8 +222,9 @@ define <2 x i1> @udiv_ne_umax(<2 x i5> %x, <2 x i5> %y) {
 
 define i1 @udiv_eq_big(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_eq_big(
-; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[D]], -128
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[X:%.*]], -128
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i8 [[Y:%.*]], 1
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %d = udiv i8 %x, %y
@@ -231,8 +234,9 @@ define i1 @udiv_eq_big(i8 %x, i8 %y) {
 
 define i1 @udiv_ne_big(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_ne_big(
-; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[D]], -128
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne i8 [[X:%.*]], -128
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i8 [[Y:%.*]], 1
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %d = udiv i8 %x, %y
@@ -240,6 +244,8 @@ define i1 @udiv_ne_big(i8 %x, i8 %y) {
   ret i1 %r
 }
 
+; negative test - must have negative compare constant
+
 define i1 @udiv_eq_not_big(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_eq_not_big(
 ; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
@@ -251,6 +257,8 @@ define i1 @udiv_eq_not_big(i8 %x, i8 %y) {
   ret i1 %r
 }
 
+; negative test - must be equality predicate
+
 define i1 @udiv_slt_umax(i8 %x, i8 %y) {
 ; CHECK-LABEL: @udiv_slt_umax(
 ; CHECK-NEXT:    [[D:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
@@ -262,6 +270,8 @@ define i1 @udiv_slt_umax(i8 %x, i8 %y) {
   ret i1 %r
 }
 
+; negative test - extra use
+
 define i1 @udiv_eq_umax_use(i32 %x, i32 %y) {
 ; CHECK-LABEL: @udiv_eq_umax_use(
 ; CHECK-NEXT:    [[D:%.*]] = udiv i32 [[X:%.*]], [[Y:%.*]]