[InstCombine] Optimize compares with multiple selects as operands
authorTejas Joshi <TejasSanjay.Joshi@amd.com>
Fri, 26 May 2023 14:02:22 +0000 (16:02 +0200)
committerNikita Popov <npopov@redhat.com>
Fri, 26 May 2023 14:05:32 +0000 (16:05 +0200)
In case of a comparison with two select instructions having the same
condition, check whether one of the resulting branches can be simplified.
If so, just compare the other branch and select the appropriate result.
For example:

    %tmp1 = select i1 %cmp, i32 %y, i32 %x
    %tmp2 = select i1 %cmp, i32 %z, i32 %x
    %cmp2 = icmp slt i32 %tmp2, %tmp1

The icmp will result false for the false value of selects and the result
will depend upon the comparison of true values of selects if %cmp is
true. Thus, transform this into:

    %cmp = icmp slt i32 %y, %z
    %sel = select i1 %cond, i1 %cmp, i1 false

Differential Revision: https://reviews.llvm.org/D150360

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/icmp-with-selects.ll

index e1a8073..462e65d 100644 (file)
@@ -6577,6 +6577,37 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I))
       return NI;
 
+  // In case of a comparison with two select instructions having the same
+  // condition, check whether one of the resulting branches can be simplified.
+  // If so, just compare the other branch and select the appropriate result.
+  // For example:
+  //   %tmp1 = select i1 %cmp, i32 %y, i32 %x
+  //   %tmp2 = select i1 %cmp, i32 %z, i32 %x
+  //   %cmp2 = icmp slt i32 %tmp2, %tmp1
+  // The icmp will result false for the false value of selects and the result
+  // will depend upon the comparison of true values of selects if %cmp is
+  // true. Thus, transform this into:
+  //   %cmp = icmp slt i32 %y, %z
+  //   %sel = select i1 %cond, i1 %cmp, i1 false
+  // This handles similar cases to transform.
+  {
+    Value *Cond, *A, *B, *C, *D;
+    if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) &&
+        match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) &&
+        (Op0->hasOneUse() || Op1->hasOneUse())) {
+      // Check whether comparison of TrueValues can be simplified
+      if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) {
+        Value *NewICMP = Builder.CreateICmp(Pred, B, D);
+        return SelectInst::Create(Cond, Res, NewICMP);
+      }
+      // Check whether comparison of FalseValues can be simplified
+      if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) {
+        Value *NewICMP = Builder.CreateICmp(Pred, A, C);
+        return SelectInst::Create(Cond, NewICMP, Res);
+      }
+    }
+  }
+
   // Try to optimize equality comparisons against alloca-based pointers.
   if (Op0->getType()->isPointerTy() && I.isEquality()) {
     assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
index 540ecce..9ee7c78 100644 (file)
@@ -7,10 +7,7 @@ define i1 @both_sides_fold_slt(i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @both_sides_fold_slt
 ; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 1, i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 9, i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %cond1 = select i1 %cond, i32 1, i32 %param
@@ -23,10 +20,8 @@ define i1 @both_sides_fold_eq(i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @both_sides_fold_eq
 ; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 1, i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 9, i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[COND6]], [[COND1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    ret i1 [[NOT_COND]]
 ;
 entry:
   %cond1 = select i1 %cond, i32 1, i32 %param
@@ -39,9 +34,8 @@ define i1 @one_side_fold_slt(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @one_side_fold_slt
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -55,9 +49,9 @@ define i1 @one_side_fold_sgt(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @one_side_fold_sgt
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[PARAM]], i32 [[VAL1]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[PARAM]], i32 [[VAL2]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[NOT_COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -71,9 +65,9 @@ define i1 @one_side_fold_eq(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @one_side_fold_eq
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[TMP0]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -120,9 +114,9 @@ define i1 @one_select_mult_use(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
 ; CHECK-NEXT:    call void @use(i32 [[COND1]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -155,9 +149,8 @@ define <4 x i1> @fold_vector_ops(<4 x i32> %val1, <4 x i32> %val2, <4 x i32> %pa
 ; CHECK-LABEL: define <4 x i1> @fold_vector_ops
 ; CHECK-SAME: (<4 x i32> [[VAL1:%.*]], <4 x i32> [[VAL2:%.*]], <4 x i32> [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], <4 x i32> [[VAL1]], <4 x i32> [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], <4 x i32> [[VAL2]], <4 x i32> [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <4 x i32> [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq <4 x i32> [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], <4 x i1> [[TMP0]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>
 ; CHECK-NEXT:    ret <4 x i1> [[CMP]]
 ;
 entry:
@@ -171,9 +164,8 @@ define <8 x i1> @fold_vector_cond_ops(<8 x i32> %val1, <8 x i32> %val2, <8 x i32
 ; CHECK-LABEL: define <8 x i1> @fold_vector_cond_ops
 ; CHECK-SAME: (<8 x i32> [[VAL1:%.*]], <8 x i32> [[VAL2:%.*]], <8 x i32> [[PARAM:%.*]], <8 x i1> [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select <8 x i1> [[COND]], <8 x i32> [[VAL1]], <8 x i32> [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select <8 x i1> [[COND]], <8 x i32> [[VAL2]], <8 x i32> [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <8 x i32> [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt <8 x i32> [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select <8 x i1> [[COND]], <8 x i1> [[TMP0]], <8 x i1> zeroinitializer
 ; CHECK-NEXT:    ret <8 x i1> [[CMP]]
 ;
 entry: