From 978f827d122bad37e21081fc32fad6b7d8b6edea Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Sat, 29 Oct 2016 15:22:04 +0000 Subject: [PATCH] [InstCombine] re-use bitcasted compare operands in selects (PR28001) These mixed bitcast patterns show up with SSE/AVX intrinsics because we bitcast function parameters to <2 x i64>. The bitcasts obfuscate the expected min/max forms as shown in PR28001: https://llvm.org/bugs/show_bug.cgi?id=28001#c6 Differential Revision: https://reviews.llvm.org/D25943 llvm-svn: 285495 --- .../Transforms/InstCombine/InstCombineSelect.cpp | 50 ++++++++++++++++++++++ llvm/test/Transforms/InstCombine/minmax-fold.ll | 15 +++---- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 165b54b..af6b013 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1009,6 +1009,53 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { ConstantVector::get(Mask)); } +/// Reuse bitcasted operands between a compare and select: +/// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> +/// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D)) +static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + + CmpInst::Predicate Pred; + Value *A, *B; + if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) + return nullptr; + + // The select condition is a compare instruction. If the select's true/false + // values are already the same as the compare operands, there's nothing to do. + if (TVal == A || TVal == B || FVal == A || FVal == B) + return nullptr; + + Value *C, *D; + if (!match(A, m_BitCast(m_Value(C))) || !match(B, m_BitCast(m_Value(D)))) + return nullptr; + + // select (cmp (bitcast C), (bitcast D)), (bitcast TSrc), (bitcast FSrc) + Value *TSrc, *FSrc; + if (!match(TVal, m_BitCast(m_Value(TSrc))) || + !match(FVal, m_BitCast(m_Value(FSrc)))) + return nullptr; + + // If the select true/false values are *different bitcasts* of the same source + // operands, make the select operands the same as the compare operands and + // cast the result. This is the canonical select form for min/max. + Value *NewSel; + if (TSrc == C && FSrc == D) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> + // bitcast (select (cmp A, B), A, B) + NewSel = Builder.CreateSelect(Cond, A, B, "", &Sel); + } else if (TSrc == D && FSrc == C) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' D), (bitcast' C) --> + // bitcast (select (cmp A, B), B, A) + NewSel = Builder.CreateSelect(Cond, B, A, "", &Sel); + } else { + return nullptr; + } + return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1369,5 +1416,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder)) + return BitCastSel; + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll index 9cbaed3..aa1ab6d 100644 --- a/llvm/test/Transforms/InstCombine/minmax-fold.ll +++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll @@ -144,9 +144,8 @@ define <4 x i32> @bitcasts_fcmp_1(<2 x i64> %a, <2 x i64> %b) { ; CHECK-NEXT: [[T0:%.*]] = bitcast <2 x i64> %a to <4 x float> ; CHECK-NEXT: [[T1:%.*]] = bitcast <2 x i64> %b to <4 x float> ; CHECK-NEXT: [[T2:%.*]] = fcmp olt <4 x float> [[T1]], [[T0]] -; CHECK-NEXT: [[T3:%.*]] = bitcast <2 x i64> %a to <4 x i32> -; CHECK-NEXT: [[T4:%.*]] = bitcast <2 x i64> %b to <4 x i32> -; CHECK-NEXT: [[T5:%.*]] = select <4 x i1> [[T2]], <4 x i32> [[T3]], <4 x i32> [[T4]] +; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[T2]], <4 x float> [[T0]], <4 x float> [[T1]] +; CHECK-NEXT: [[T5:%.*]] = bitcast <4 x float> [[TMP1]] to <4 x i32> ; CHECK-NEXT: ret <4 x i32> [[T5]] ; %t0 = bitcast <2 x i64> %a to <4 x float> @@ -165,9 +164,8 @@ define <4 x i32> @bitcasts_fcmp_2(<2 x i64> %a, <2 x i64> %b) { ; CHECK-NEXT: [[T0:%.*]] = bitcast <2 x i64> %a to <4 x float> ; CHECK-NEXT: [[T1:%.*]] = bitcast <2 x i64> %b to <4 x float> ; CHECK-NEXT: [[T2:%.*]] = fcmp olt <4 x float> [[T0]], [[T1]] -; CHECK-NEXT: [[T3:%.*]] = bitcast <2 x i64> %a to <4 x i32> -; CHECK-NEXT: [[T4:%.*]] = bitcast <2 x i64> %b to <4 x i32> -; CHECK-NEXT: [[T5:%.*]] = select <4 x i1> [[T2]], <4 x i32> [[T3]], <4 x i32> [[T4]] +; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[T2]], <4 x float> [[T0]], <4 x float> [[T1]] +; CHECK-NEXT: [[T5:%.*]] = bitcast <4 x float> [[TMP1]] to <4 x i32> ; CHECK-NEXT: ret <4 x i32> [[T5]] ; %t0 = bitcast <2 x i64> %a to <4 x float> @@ -186,9 +184,8 @@ define <4 x float> @bitcasts_icmp(<2 x i64> %a, <2 x i64> %b) { ; CHECK-NEXT: [[T0:%.*]] = bitcast <2 x i64> %a to <4 x i32> ; CHECK-NEXT: [[T1:%.*]] = bitcast <2 x i64> %b to <4 x i32> ; CHECK-NEXT: [[T2:%.*]] = icmp slt <4 x i32> [[T1]], [[T0]] -; CHECK-NEXT: [[T3:%.*]] = bitcast <2 x i64> %a to <4 x float> -; CHECK-NEXT: [[T4:%.*]] = bitcast <2 x i64> %b to <4 x float> -; CHECK-NEXT: [[T5:%.*]] = select <4 x i1> [[T2]], <4 x float> [[T3]], <4 x float> [[T4]] +; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[T2]], <4 x i32> [[T0]], <4 x i32> [[T1]] +; CHECK-NEXT: [[T5:%.*]] = bitcast <4 x i32> [[TMP1]] to <4 x float> ; CHECK-NEXT: ret <4 x float> [[T5]] ; %t0 = bitcast <2 x i64> %a to <4 x i32> -- 2.7.4