[InstCombine] push constant operand down/outside in sequence of min/max intrinsics
authorSanjay Patel <spatel@rotateright.com>
Thu, 17 Feb 2022 15:34:48 +0000 (10:34 -0500)
committerSanjay Patel <spatel@rotateright.com>
Thu, 17 Feb 2022 15:36:37 +0000 (10:36 -0500)
A generalization like this was suggested in D119754.
This is the inverse direction of D119851,
and we get all of the folds there plus the one that was missed.

There is precedence for this kind of transform in instcombine
with "or" instructions (but strangely only with that one opcode AFAICT).

Similar justification as in the other patch:
The line between instcombine and reassociate for these kinds of folds
is blurry. This doesn't appear to have much cost and gives us the
expected wins from repeated folds as seen in the last set of test diffs.

Differential Revision: https://reviews.llvm.org/D119955

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/minmax-intrinsics.ll

index d2a104b..eecd583 100644 (file)
@@ -902,6 +902,36 @@ static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) {
   return CallInst::Create(MinMax, {LHS->getArgOperand(0), NewC});
 }
 
+/// If this min/max has a matching min/max operand with a constant, try to push
+/// the constant operand into this instruction. This can enable more folds.
+static Instruction *
+reassociateMinMaxWithConstantInOperand(IntrinsicInst *II,
+                                       InstCombiner::BuilderTy &Builder) {
+  // Match and capture a min/max operand candidate.
+  Value *X, *Y;
+  Constant *C;
+  Instruction *Inner;
+  if (!match(II, m_c_MaxOrMin(m_OneUse(m_CombineAnd(
+                                  m_Instruction(Inner),
+                                  m_MaxOrMin(m_Value(X), m_ImmConstant(C)))),
+                              m_Value(Y))))
+    return nullptr;
+
+  // The inner op must match. Check for constants to avoid infinite loops.
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  auto *InnerMM = dyn_cast<IntrinsicInst>(Inner);
+  if (!InnerMM || InnerMM->getIntrinsicID() != MinMaxID ||
+      match(X, m_ImmConstant()) || match(Y, m_ImmConstant()))
+    return nullptr;
+
+  // max (max X, C), Y --> max (max X, Y), C
+  Function *MinMax =
+      Intrinsic::getDeclaration(II->getModule(), MinMaxID, II->getType());
+  Value *NewInner = Builder.CreateBinaryIntrinsic(MinMaxID, X, Y);
+  NewInner->takeName(Inner);
+  return CallInst::Create(MinMax, {NewInner, C});
+}
+
 /// Reduce a sequence of min/max intrinsics with a common operand.
 static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
   // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
@@ -1250,6 +1280,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *NewMinMax = reassociateMinMaxWithConstants(II))
       return NewMinMax;
 
+    if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder))
+      return R;
+
     if (Instruction *NewMinMax = factorizeMinMaxTree(II))
        return NewMinMax;
 
index 4c95d45..bc4da86 100644 (file)
@@ -2259,8 +2259,8 @@ define i8 @umin_umin_reassoc_constant_use(i8 %x, i8 %y) {
 
 define i8 @smax_smax_reassoc_constant_sink(i8 %x, i8 %y) {
 ; CHECK-LABEL: @smax_smax_reassoc_constant_sink(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 42)
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 42)
 ; CHECK-NEXT:    ret i8 [[M2]]
 ;
   %m1 = call i8 @llvm.smax.i8(i8 %x, i8 42)
@@ -2270,8 +2270,8 @@ define i8 @smax_smax_reassoc_constant_sink(i8 %x, i8 %y) {
 
 define <3 x i8> @smin_smin_reassoc_constant_sink(<3 x i8> %x, <3 x i8> %y) {
 ; CHECK-LABEL: @smin_smin_reassoc_constant_sink(
-; CHECK-NEXT:    [[M1:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[X:%.*]], <3 x i8> <i8 43, i8 -43, i8 0>)
-; CHECK-NEXT:    [[M2:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[M1]], <3 x i8> [[Y:%.*]])
+; CHECK-NEXT:    [[M1:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[X:%.*]], <3 x i8> [[Y:%.*]])
+; CHECK-NEXT:    [[M2:%.*]] = call <3 x i8> @llvm.smin.v3i8(<3 x i8> [[M1]], <3 x i8> <i8 43, i8 -43, i8 0>)
 ; CHECK-NEXT:    ret <3 x i8> [[M2]]
 ;
   %m1 = call <3 x i8> @llvm.smin.v3i8(<3 x i8> %x, <3 x i8> <i8 43, i8 -43, i8 0>)
@@ -2281,8 +2281,8 @@ define <3 x i8> @smin_smin_reassoc_constant_sink(<3 x i8> %x, <3 x i8> %y) {
 
 define i8 @umax_umax_reassoc_constant_sink(i8 %x, i8 %y) {
 ; CHECK-LABEL: @umax_umax_reassoc_constant_sink(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 42)
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 42)
 ; CHECK-NEXT:    ret i8 [[M2]]
 ;
   %m1 = call i8 @llvm.umax.i8(i8 %x, i8 42)
@@ -2292,8 +2292,8 @@ define i8 @umax_umax_reassoc_constant_sink(i8 %x, i8 %y) {
 
 define <3 x i8> @umin_umin_reassoc_constant_sink(<3 x i8> %x, <3 x i8> %y) {
 ; CHECK-LABEL: @umin_umin_reassoc_constant_sink(
-; CHECK-NEXT:    [[M1:%.*]] = call <3 x i8> @llvm.umin.v3i8(<3 x i8> [[X:%.*]], <3 x i8> <i8 43, i8 -43, i8 0>)
-; CHECK-NEXT:    [[M2:%.*]] = call <3 x i8> @llvm.umin.v3i8(<3 x i8> [[M1]], <3 x i8> [[Y:%.*]])
+; CHECK-NEXT:    [[M1:%.*]] = call <3 x i8> @llvm.umin.v3i8(<3 x i8> [[X:%.*]], <3 x i8> [[Y:%.*]])
+; CHECK-NEXT:    [[M2:%.*]] = call <3 x i8> @llvm.umin.v3i8(<3 x i8> [[M1]], <3 x i8> <i8 43, i8 -43, i8 0>)
 ; CHECK-NEXT:    ret <3 x i8> [[M2]]
 ;
   %m1 = call <3 x i8> @llvm.umin.v3i8(<3 x i8> %x, <3 x i8> <i8 43, i8 -43, i8 0>)
@@ -2316,9 +2316,8 @@ define i8 @umin_umin_reassoc_constant_sink_use(i8 %x, i8 %y) {
 
 define i8 @smax_smax_smax_reassoc_constants(i8 %x, i8 %y) {
 ; CHECK-LABEL: @smax_smax_smax_reassoc_constants(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 42)
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[Y:%.*]], i8 [[M1]])
-; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M2]], i8 126)
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 126)
 ; CHECK-NEXT:    ret i8 [[M3]]
 ;
   %m1 = call i8 @llvm.smax.i8(i8 %x, i8 42)
@@ -2329,9 +2328,8 @@ define i8 @smax_smax_smax_reassoc_constants(i8 %x, i8 %y) {
 
 define i8 @smax_smax_smax_reassoc_constants_swap(i8 %x, i8 %y) {
 ; CHECK-LABEL: @smax_smax_smax_reassoc_constants_swap(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 42)
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 [[Y:%.*]])
-; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M2]], i8 126)
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smax.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smax.i8(i8 [[M1]], i8 126)
 ; CHECK-NEXT:    ret i8 [[M3]]
 ;
   %m1 = call i8 @llvm.smax.i8(i8 %x, i8 42)
@@ -2342,10 +2340,9 @@ define i8 @smax_smax_smax_reassoc_constants_swap(i8 %x, i8 %y) {
 
 define i8 @smin_smin_smin_reassoc_constants(i8 %x, i8 %y) {
 ; CHECK-LABEL: @smin_smin_smin_reassoc_constants(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smin.i8(i8 [[X:%.*]], i8 42)
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smin.i8(i8 [[Y:%.*]], i8 [[M1]])
-; CHECK-NEXT:    [[M3:%.*]] = call i8 @llvm.smin.i8(i8 [[M2]], i8 126)
-; CHECK-NEXT:    ret i8 [[M3]]
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.smin.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.smin.i8(i8 [[M1]], i8 42)
+; CHECK-NEXT:    ret i8 [[M2]]
 ;
   %m1 = call i8 @llvm.smin.i8(i8 %x, i8 42)
   %m2 = call i8 @llvm.smin.i8(i8 %y, i8 %m1)
@@ -2355,8 +2352,8 @@ define i8 @smin_smin_smin_reassoc_constants(i8 %x, i8 %y) {
 
 define i8 @umax_umax_reassoc_constantexpr_sink(i8 %x, i8 %y) {
 ; CHECK-LABEL: @umax_umax_reassoc_constantexpr_sink(
-; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 42)
-; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 ptrtoint (i8 (i8, i8)* @umax_umax_reassoc_constantexpr_sink to i8))
+; CHECK-NEXT:    [[M1:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 ptrtoint (i8 (i8, i8)* @umax_umax_reassoc_constantexpr_sink to i8))
+; CHECK-NEXT:    [[M2:%.*]] = call i8 @llvm.umax.i8(i8 [[M1]], i8 42)
 ; CHECK-NEXT:    ret i8 [[M2]]
 ;
   %m1 = call i8 @llvm.umax.i8(i8 %x, i8 42)