[InstCombine] Add support for max(a,b) + min(a,b) => a + b.
authorSerguei Katkov <serguei.katkov@azul.com>
Fri, 31 Mar 2023 02:35:06 +0000 (09:35 +0700)
committerSerguei Katkov <serguei.katkov@azul.com>
Fri, 7 Apr 2023 03:24:07 +0000 (10:24 +0700)
The same optimization for
  max(a,b) * min(a,b) => a * b
is added.

Correctness check:
uadd: https://alive2.llvm.org/ce/z/2rXDek
sadd: https://alive2.llvm.org/ce/z/zNu_er
uadd + nuw/nsw: https://alive2.llvm.org/ce/z/EaiNjB
sadd + nuw/nsw: https://alive2.llvm.org/ce/z/w_2Nrs

umul: https://alive2.llvm.org/ce/z/dgXRLr
smul: https://alive2.llvm.org/ce/z/hBjGzz
umul + nuw/nsw: https://alive2.llvm.org/ce/z/EaiNjB
smul + nuw/nsw: https://alive2.llvm.org/ce/z/87MNeS

Reviewed By: goldstein.w.n
Differential Revision: https://reviews.llvm.org/D147296

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
llvm/test/Transforms/InstCombine/add-min-max.ll
llvm/test/Transforms/InstCombine/mul-min-max.ll

index 2c2b767..7dd9fc4 100644 (file)
@@ -1554,6 +1554,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *Ashr = foldAddToAshr(I))
     return Ashr;
 
+  // min(A, B) + max(A, B) => A + B.
+  if (match(&I,
+            match_combine_or(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
+                                     m_c_SMin(m_Deferred(A), m_Deferred(B))),
+                             m_c_Add(m_UMax(m_Value(A), m_Value(B)),
+                                     m_c_UMin(m_Deferred(A), m_Deferred(B))))))
+    return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
+
   // TODO(jingyue): Consider willNotOverflowSignedAdd and
   // willNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
index 5768f71..a3baff3 100644 (file)
@@ -471,6 +471,14 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *Ext = narrowMathIfNoOverflow(I))
     return Ext;
 
+  // min(X, Y) * max(X, Y) => X * Y.
+  if (match(&I,
+            match_combine_or(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
+                                     m_c_SMin(m_Deferred(X), m_Deferred(Y))),
+                             m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
+                                     m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
+    return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
+
   bool Changed = false;
   if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
     Changed = true;
index d077d18..2117a55 100644 (file)
@@ -9,9 +9,7 @@ declare i32 @llvm.umin.i32(i32 %a, i32 %b)
 define i32 @uadd_min_max(i32 %a, i32 %b) {
 ; CHECK-LABEL: @uadd_min_max(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -24,9 +22,7 @@ entry:
 define i32 @uadd_min_max_comm(i32 %a, i32 %b) {
 ; CHECK-LABEL: @uadd_min_max_comm(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[B:%.*]], i32 [[A:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -39,9 +35,7 @@ entry:
 define i32 @uadd_min_max_nuw_nsw(i32 %a, i32 %b) {
 ; CHECK-LABEL: @uadd_min_max_nuw_nsw(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -54,9 +48,7 @@ entry:
 define i32 @sadd_min_max(i32 %a, i32 %b) {
 ; CHECK-LABEL: @sadd_min_max(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -69,9 +61,7 @@ entry:
 define i32 @sadd_min_max_comm(i32 %a, i32 %b) {
 ; CHECK-LABEL: @sadd_min_max_comm(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[B:%.*]], i32 [[A:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = add i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = add i32 [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -84,9 +74,7 @@ entry:
 define i32 @sadd_min_max_nuw_nsw(i32 %a, i32 %b) {
 ; CHECK-LABEL: @sadd_min_max_nuw_nsw(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
index e6808ff..fce6b95 100644 (file)
@@ -9,9 +9,7 @@ declare i32 @llvm.umin.i32(i32 %a, i32 %b)
 define i32 @umul_min_max(i32 %a, i32 %b) {
 ; CHECK-LABEL: @umul_min_max(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -24,9 +22,7 @@ entry:
 define i32 @umul_min_max_comm(i32 %a, i32 %b) {
 ; CHECK-LABEL: @umul_min_max_comm(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[B:%.*]], i32 [[A:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -39,9 +35,7 @@ entry:
 define i32 @umul_min_max_nuw_nsw(i32 %a, i32 %b) {
 ; CHECK-LABEL: @umul_min_max_nuw_nsw(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = mul nuw nsw i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = mul nuw nsw i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -54,9 +48,7 @@ entry:
 define i32 @smul_min_max(i32 %a, i32 %b) {
 ; CHECK-LABEL: @smul_min_max(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -69,9 +61,7 @@ entry:
 define i32 @smul_min_max_comm(i32 %a, i32 %b) {
 ; CHECK-LABEL: @smul_min_max_comm(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[B:%.*]], i32 [[A:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = mul i32 [[B:%.*]], [[A:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -84,9 +74,7 @@ entry:
 define i32 @smul_min_max_nuw_nsw(i32 %a, i32 %b) {
 ; CHECK-LABEL: @smul_min_max_nuw_nsw(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]])
-; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT:    [[RES:%.*]] = mul nuw nsw i32 [[MIN]], [[MAX]]
+; CHECK-NEXT:    [[RES:%.*]] = mul nuw nsw i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry: