From b86a06ef284f2637bef89bf5bb20157a8b195568 Mon Sep 17 00:00:00 2001 From: Serguei Katkov Date: Fri, 31 Mar 2023 09:35:06 +0700 Subject: [PATCH] [InstCombine] Add support for max(a,b) + min(a,b) => a + b. The same optimization for max(a,b) * min(a,b) => a * b is added. Correctness check: uadd: https://alive2.llvm.org/ce/z/2rXDek sadd: https://alive2.llvm.org/ce/z/zNu_er uadd + nuw/nsw: https://alive2.llvm.org/ce/z/EaiNjB sadd + nuw/nsw: https://alive2.llvm.org/ce/z/w_2Nrs umul: https://alive2.llvm.org/ce/z/dgXRLr smul: https://alive2.llvm.org/ce/z/hBjGzz umul + nuw/nsw: https://alive2.llvm.org/ce/z/EaiNjB smul + nuw/nsw: https://alive2.llvm.org/ce/z/87MNeS Reviewed By: goldstein.w.n Differential Revision: https://reviews.llvm.org/D147296 --- .../Transforms/InstCombine/InstCombineAddSub.cpp | 8 ++++++++ .../InstCombine/InstCombineMulDivRem.cpp | 8 ++++++++ llvm/test/Transforms/InstCombine/add-min-max.ll | 24 ++++++---------------- llvm/test/Transforms/InstCombine/mul-min-max.ll | 24 ++++++---------------- 4 files changed, 28 insertions(+), 36 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 2c2b767..7dd9fc4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1554,6 +1554,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Instruction *Ashr = foldAddToAshr(I)) return Ashr; + // min(A, B) + max(A, B) => A + B. + if (match(&I, + match_combine_or(m_c_Add(m_SMax(m_Value(A), m_Value(B)), + m_c_SMin(m_Deferred(A), m_Deferred(B))), + m_c_Add(m_UMax(m_Value(A), m_Value(B)), + m_c_UMin(m_Deferred(A), m_Deferred(B)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I); + // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 5768f71..a3baff3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -471,6 +471,14 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; + // min(X, Y) * max(X, Y) => X * Y. + if (match(&I, + match_combine_or(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)), + m_c_SMin(m_Deferred(X), m_Deferred(Y))), + m_c_Mul(m_UMax(m_Value(X), m_Value(Y)), + m_c_UMin(m_Deferred(X), m_Deferred(Y)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I); + bool Changed = false; if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; diff --git a/llvm/test/Transforms/InstCombine/add-min-max.ll b/llvm/test/Transforms/InstCombine/add-min-max.ll index d077d18..2117a55 100644 --- a/llvm/test/Transforms/InstCombine/add-min-max.ll +++ b/llvm/test/Transforms/InstCombine/add-min-max.ll @@ -9,9 +9,7 @@ declare i32 @llvm.umin.i32(i32 %a, i32 %b) define i32 @uadd_min_max(i32 %a, i32 %b) { ; CHECK-LABEL: @uadd_min_max( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = add i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = add i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -24,9 +22,7 @@ entry: define i32 @uadd_min_max_comm(i32 %a, i32 %b) { ; CHECK-LABEL: @uadd_min_max_comm( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[B:%.*]], i32 [[A:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = add i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = add i32 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -39,9 +35,7 @@ entry: define i32 @uadd_min_max_nuw_nsw(i32 %a, i32 %b) { ; CHECK-LABEL: @uadd_min_max_nuw_nsw( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -54,9 +48,7 @@ entry: define i32 @sadd_min_max(i32 %a, i32 %b) { ; CHECK-LABEL: @sadd_min_max( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = add i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = add i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -69,9 +61,7 @@ entry: define i32 @sadd_min_max_comm(i32 %a, i32 %b) { ; CHECK-LABEL: @sadd_min_max_comm( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[B:%.*]], i32 [[A:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = add i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = add i32 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -84,9 +74,7 @@ entry: define i32 @sadd_min_max_nuw_nsw(i32 %a, i32 %b) { ; CHECK-LABEL: @sadd_min_max_nuw_nsw( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: diff --git a/llvm/test/Transforms/InstCombine/mul-min-max.ll b/llvm/test/Transforms/InstCombine/mul-min-max.ll index e6808ff..fce6b95 100644 --- a/llvm/test/Transforms/InstCombine/mul-min-max.ll +++ b/llvm/test/Transforms/InstCombine/mul-min-max.ll @@ -9,9 +9,7 @@ declare i32 @llvm.umin.i32(i32 %a, i32 %b) define i32 @umul_min_max(i32 %a, i32 %b) { ; CHECK-LABEL: @umul_min_max( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = mul i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = mul i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -24,9 +22,7 @@ entry: define i32 @umul_min_max_comm(i32 %a, i32 %b) { ; CHECK-LABEL: @umul_min_max_comm( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[B:%.*]], i32 [[A:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = mul i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = mul i32 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -39,9 +35,7 @@ entry: define i32 @umul_min_max_nuw_nsw(i32 %a, i32 %b) { ; CHECK-LABEL: @umul_min_max_nuw_nsw( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = mul nuw nsw i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = mul nuw nsw i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -54,9 +48,7 @@ entry: define i32 @smul_min_max(i32 %a, i32 %b) { ; CHECK-LABEL: @smul_min_max( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = mul i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = mul i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -69,9 +61,7 @@ entry: define i32 @smul_min_max_comm(i32 %a, i32 %b) { ; CHECK-LABEL: @smul_min_max_comm( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[B:%.*]], i32 [[A:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = mul i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = mul i32 [[B:%.*]], [[A:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: @@ -84,9 +74,7 @@ entry: define i32 @smul_min_max_nuw_nsw(i32 %a, i32 %b) { ; CHECK-LABEL: @smul_min_max_nuw_nsw( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[MAX:%.*]] = call i32 @llvm.smax.i32(i32 [[A:%.*]], i32 [[B:%.*]]) -; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[A]], i32 [[B]]) -; CHECK-NEXT: [[RES:%.*]] = mul nuw nsw i32 [[MIN]], [[MAX]] +; CHECK-NEXT: [[RES:%.*]] = mul nuw nsw i32 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: ret i32 [[RES]] ; entry: -- 2.7.4