From 6de1dbbd09c12abbec7eb187ffa1afbd47302dfa Mon Sep 17 00:00:00 2001
From: Sanjay Patel
Date: Thu, 12 Aug 2021 08:34:30 -0400
Subject: [PATCH] [InstCombine] factorize min/max intrinsic ops with common operand

This is an adaptation of D41603 and another step on the way
to canonicalizing to the intrinsic forms of min/max.
See D98152 for status.
---
 .../Transforms/InstCombine/InstCombineCalls.cpp    | 69 ++++++++++++++++++++++
 .../Transforms/InstCombine/minmax-intrinsics.ll    | 18 +++----
 2 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 210652e..bda8c25 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -795,6 +795,72 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II,
   return SelectInst::Create(Cmp, ConstantInt::get(II->getType(), *C0), I1);
 }
 
+/// Reduce a sequence of min/max intrinsics with a common operand.
+///
+/// Matches II = minmax(minmax(A, B), minmax(C, D)) where all three calls use
+/// the same intrinsic, at least one inner call has a single use, and the inner
+/// calls share an operand. On a match, returns a new call of that intrinsic on
+/// one inner call and the non-shared operand of the other (single-use) inner
+/// call. Returns nullptr otherwise. \p Builder is not used by this function.
+static Instruction *factorizeMinMaxTree(IntrinsicInst *II,
+                                        InstCombiner::BuilderTy &Builder) {
+  // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  if (!LHS || !RHS || LHS->getIntrinsicID() != MinMaxID ||
+      RHS->getIntrinsicID() != MinMaxID ||
+      (!LHS->hasOneUse() && !RHS->hasOneUse()))
+    return nullptr;
+
+  Value *A = LHS->getArgOperand(0);
+  Value *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0);
+  Value *D = RHS->getArgOperand(1);
+
+  // Look for a common operand.
+  Value *MinMaxOp = nullptr;
+  Value *ThirdOp = nullptr;
+  if (LHS->hasOneUse()) {
+    // If the LHS is only used in this chain and the RHS is used outside of it,
+    // reuse the RHS min/max because that will eliminate the LHS.
+    if (D == A || C == A) {
+      // min(min(a, b), min(c, a)) --> min(min(c, a), b)
+      // min(min(a, b), min(a, d)) --> min(min(a, d), b)
+      MinMaxOp = RHS;
+      ThirdOp = B;
+    } else if (D == B || C == B) {
+      // min(min(a, b), min(c, b)) --> min(min(c, b), a)
+      // min(min(a, b), min(b, d)) --> min(min(b, d), a)
+      MinMaxOp = RHS;
+      ThirdOp = A;
+    }
+  } else {
+    assert(RHS->hasOneUse() && "Expected one-use operand");
+    // Reuse the LHS. This will eliminate the RHS.
+    if (D == A || D == B) {
+      // min(min(a, b), min(c, a)) --> min(min(a, b), c)
+      // min(min(a, b), min(c, b)) --> min(min(a, b), c)
+      MinMaxOp = LHS;
+      ThirdOp = C;
+    } else if (C == A || C == B) {
+      // min(min(a, b), min(a, d)) --> min(min(a, b), d)
+      // min(min(a, b), min(b, d)) --> min(min(a, b), d)
+      MinMaxOp = LHS;
+      ThirdOp = D;
+    }
+  }
+
+  // The inner calls may not share any operand; in that case both values are
+  // still null and there is nothing to factorize.
+  if (!MinMaxOp || !ThirdOp)
+    return nullptr;
+
+  Module *Mod = II->getModule();
+  Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType());
+  return CallInst::Create(MinMax, { MinMaxOp, ThirdOp });
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallBase to do the heavy
 /// lifting.
@@ -1056,6 +1111,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *R = FoldOpIntoSelect(*II, Sel)) return R; + if (Instruction *NewMinMax = factorizeMinMaxTree(II, Builder)) + return NewMinMax; + break; } case Intrinsic::bswap: { diff --git a/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll b/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll index faa95e8..ec57a79 100644 --- a/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll +++ b/llvm/test/Transforms/InstCombine/minmax-intrinsics.ll @@ -907,9 +907,8 @@ define <3 x i1> @umin_ne_zero2(<3 x i8> %a, <3 x i8> %b) { define i8 @smax(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @smax( -; CHECK-NEXT: [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) -; CHECK-NEXT: [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[Z:%.*]]) -; CHECK-NEXT: [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 [[M2]]) +; CHECK-NEXT: [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Z:%.*]]) +; CHECK-NEXT: [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M2]], i8 [[Y:%.*]]) ; CHECK-NEXT: ret i8 [[M3]] ; %m1 = call i8 @llvm.smax.i8(i8 %x, i8 %y) @@ -920,9 +919,8 @@ define i8 @smax(i8 %x, i8 %y, i8 %z) { define <3 x i8> @smin(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @smin( -; CHECK-NEXT: [[M1:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[Y:%.*]], <3 x i8> [[X:%.*]]) -; CHECK-NEXT: [[M2:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[X]], <3 x i8> [[Z:%.*]]) -; CHECK-NEXT: [[M3:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[M1]], <3 x i8> [[M2]]) +; CHECK-NEXT: [[M2:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[X:%.*]], <3 x i8> [[Z:%.*]]) +; CHECK-NEXT: [[M3:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[M2]], <3 x i8> [[Y:%.*]]) ; CHECK-NEXT: ret <3 x i8> [[M3]] ; %m1 = call <3 x i8> @llvm.smin.v3i8(<3 x i8> %y, <3 x i8> %x) @@ -935,8 +933,7 @@ define i8 @umax(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @umax( ; CHECK-NEXT: [[M1:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 
[[Y:%.*]]) ; CHECK-NEXT: call void @use(i8 [[M1]]) -; CHECK-NEXT: [[M2:%.*]] = call i8 @llvm.umax.i8(i8 [[Z:%.*]], i8 [[X]]) -; CHECK-NEXT: [[M3:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 [[M2]]) +; CHECK-NEXT: [[M3:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 [[Z:%.*]]) ; CHECK-NEXT: ret i8 [[M3]] ; %m1 = call i8 @llvm.umax.i8(i8 %x, i8 %y) @@ -948,10 +945,9 @@ define i8 @umax(i8 %x, i8 %y, i8 %z) { define i8 @umin(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @umin( -; CHECK-NEXT: [[M1:%.*]] = call i8 @llvm.umin.i8(i8 [[Y:%.*]], i8 [[X:%.*]]) -; CHECK-NEXT: [[M2:%.*]] = call i8 @llvm.umin.i8(i8 [[Z:%.*]], i8 [[X]]) +; CHECK-NEXT: [[M2:%.*]] = call i8 @llvm.umin.i8(i8 [[Z:%.*]], i8 [[X:%.*]]) ; CHECK-NEXT: call void @use(i8 [[M2]]) -; CHECK-NEXT: [[M3:%.*]] = call i8 @llvm.umin.i8(i8 [[M1]], i8 [[M2]]) +; CHECK-NEXT: [[M3:%.*]] = call i8 @llvm.umin.i8(i8 [[M2]], i8 [[Y:%.*]]) ; CHECK-NEXT: ret i8 [[M3]] ; %m1 = call i8 @llvm.umin.i8(i8 %y, i8 %x) -- 2.7.4