From d73d62c439fb1ecace5994170285da534d418173 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Wed, 5 Apr 2023 13:42:02 +0100 Subject: [PATCH] [X86] combinePredicateReduction - reuse LowerVectorAllEqual for all_of/any_of(vXi1 eq/ne) reductions --- llvm/lib/Target/X86/X86ISelLowering.cpp | 35 ++++++-------------------- llvm/test/CodeGen/X86/vector-compare-all_of.ll | 20 +++++---------- llvm/test/CodeGen/X86/vector-compare-any_of.ll | 26 +++++++------------ 3 files changed, 23 insertions(+), 58 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 02b99b7..62cbc8b 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -44976,17 +44976,16 @@ static SDValue combinePredicateReduction(SDNode *Extract, SelectionDAG &DAG, ISD::CondCode CC = cast<CondCodeSDNode>(Match.getOperand(2))->get(); if ((BinOp == ISD::AND && CC == ISD::CondCode::SETEQ) || (BinOp == ISD::OR && CC == ISD::CondCode::SETNE)) { - // If representable as a scalar integer: // For all_of(setcc(x,y,eq)) - use (iX)x == (iX)y. // For any_of(setcc(x,y,ne)) - use (iX)x != (iX)y. 
- EVT VecVT = Match.getOperand(0).getValueType(); - EVT IntVT = EVT::getIntegerVT(Ctx, VecVT.getSizeInBits()); - if (TLI.isTypeLegal(IntVT)) { - SDValue LHS = DAG.getFreeze(Match.getOperand(0)); - SDValue RHS = DAG.getFreeze(Match.getOperand(1)); - return DAG.getSetCC(DL, ExtractVT, DAG.getBitcast(IntVT, LHS), - DAG.getBitcast(IntVT, RHS), CC); - } + X86::CondCode X86CC; + SDValue LHS = DAG.getFreeze(Match.getOperand(0)); + SDValue RHS = DAG.getFreeze(Match.getOperand(1)); + APInt Mask = APInt::getAllOnes(LHS.getScalarValueSizeInBits()); + if (SDValue V = LowerVectorAllEqual(DL, LHS, RHS, CC, Mask, Subtarget, + DAG, X86CC)) + return DAG.getNode(ISD::TRUNCATE, DL, ExtractVT, + getSETCC(X86CC, V, DL, DAG)); } } if (TLI.isTypeLegal(MatchVT)) { @@ -44994,24 +44993,6 @@ static SDValue combinePredicateReduction(SDNode *Extract, SelectionDAG &DAG, EVT MovmskVT = EVT::getIntegerVT(Ctx, NumElts); Movmsk = DAG.getBitcast(MovmskVT, Match); } else { - // For all_of(setcc(x,y,eq)) - use PMOVMSKB(PCMPEQB()). - // For any_of(setcc(x,y,ne)) - use PMOVMSKB(NOT(PCMPEQB())). - if (Match.getOpcode() == ISD::SETCC) { - ISD::CondCode CC = cast<CondCodeSDNode>(Match.getOperand(2))->get(); - if ((BinOp == ISD::AND && CC == ISD::CondCode::SETEQ) || - (BinOp == ISD::OR && CC == ISD::CondCode::SETNE)) { - EVT VecSVT = Match.getOperand(0).getValueType().getScalarType(); - if (VecSVT != MVT::i8 && (VecSVT.getSizeInBits() % 8) == 0) { - NumElts *= VecSVT.getSizeInBits() / 8; - EVT CmpVT = EVT::getVectorVT(Ctx, MVT::i8, NumElts); - MatchVT = EVT::getVectorVT(Ctx, MVT::i1, NumElts); - Match = DAG.getSetCC( - DL, MatchVT, DAG.getBitcast(CmpVT, Match.getOperand(0)), - DAG.getBitcast(CmpVT, Match.getOperand(1)), CC); - } - } - } - // Use combineBitcastvxi1 to create the MOVMSK. 
while (NumElts > MaxElts) { SDValue Lo, Hi; diff --git a/llvm/test/CodeGen/X86/vector-compare-all_of.ll b/llvm/test/CodeGen/X86/vector-compare-all_of.ll index 72f7c26..66ada15 100644 --- a/llvm/test/CodeGen/X86/vector-compare-all_of.ll +++ b/llvm/test/CodeGen/X86/vector-compare-all_of.ll @@ -1392,7 +1392,7 @@ define i1 @bool_reduction_v16i16(<16 x i16> %x, <16 x i16> %y) { ; SSE2-NEXT: pcmpeqb %xmm2, %xmm0 ; SSE2-NEXT: pand %xmm1, %xmm0 ; SSE2-NEXT: pmovmskb %xmm0, %eax -; SSE2-NEXT: cmpw $-1, %ax +; SSE2-NEXT: xorl $65535, %eax # imm = 0xFFFF ; SSE2-NEXT: sete %al ; SSE2-NEXT: retq ; @@ -1407,12 +1407,8 @@ define i1 @bool_reduction_v16i16(<16 x i16> %x, <16 x i16> %y) { ; ; AVX1-LABEL: bool_reduction_v16i16: ; AVX1: # %bb.0: -; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm2 -; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm3 -; AVX1-NEXT: vpxor %xmm2, %xmm3, %xmm2 -; AVX1-NEXT: vpxor %xmm1, %xmm0, %xmm0 -; AVX1-NEXT: vpor %xmm2, %xmm0, %xmm0 -; AVX1-NEXT: vptest %xmm0, %xmm0 +; AVX1-NEXT: vxorps %ymm1, %ymm0, %ymm0 +; AVX1-NEXT: vptest %ymm0, %ymm0 ; AVX1-NEXT: sete %al ; AVX1-NEXT: vzeroupper ; AVX1-NEXT: retq @@ -1452,7 +1448,7 @@ define i1 @bool_reduction_v32i8(<32 x i8> %x, <32 x i8> %y) { ; SSE2-NEXT: pcmpeqb %xmm2, %xmm0 ; SSE2-NEXT: pand %xmm1, %xmm0 ; SSE2-NEXT: pmovmskb %xmm0, %eax -; SSE2-NEXT: cmpw $-1, %ax +; SSE2-NEXT: xorl $65535, %eax # imm = 0xFFFF ; SSE2-NEXT: sete %al ; SSE2-NEXT: retq ; @@ -1467,12 +1463,8 @@ define i1 @bool_reduction_v32i8(<32 x i8> %x, <32 x i8> %y) { ; ; AVX1-LABEL: bool_reduction_v32i8: ; AVX1: # %bb.0: -; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm2 -; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm3 -; AVX1-NEXT: vpxor %xmm2, %xmm3, %xmm2 -; AVX1-NEXT: vpxor %xmm1, %xmm0, %xmm0 -; AVX1-NEXT: vpor %xmm2, %xmm0, %xmm0 -; AVX1-NEXT: vptest %xmm0, %xmm0 +; AVX1-NEXT: vxorps %ymm1, %ymm0, %ymm0 +; AVX1-NEXT: vptest %ymm0, %ymm0 ; AVX1-NEXT: sete %al ; AVX1-NEXT: vzeroupper ; AVX1-NEXT: retq diff --git a/llvm/test/CodeGen/X86/vector-compare-any_of.ll 
b/llvm/test/CodeGen/X86/vector-compare-any_of.ll index 67c0e93..1fe8d21 100644 --- a/llvm/test/CodeGen/X86/vector-compare-any_of.ll +++ b/llvm/test/CodeGen/X86/vector-compare-any_of.ll @@ -974,9 +974,9 @@ define i1 @bool_reduction_v8f32(<8 x float> %x, <8 x float> %y) { define i1 @bool_reduction_v2i64(<2 x i64> %x, <2 x i64> %y) { ; SSE2-LABEL: bool_reduction_v2i64: ; SSE2: # %bb.0: -; SSE2-NEXT: pcmpeqb %xmm1, %xmm0 -; SSE2-NEXT: pmovmskb %xmm0, %eax -; SSE2-NEXT: cmpl $65535, %eax # imm = 0xFFFF +; SSE2-NEXT: pcmpeqd %xmm1, %xmm0 +; SSE2-NEXT: movmskps %xmm0, %eax +; SSE2-NEXT: xorl $15, %eax ; SSE2-NEXT: setne %al ; SSE2-NEXT: retq ; @@ -987,20 +987,12 @@ define i1 @bool_reduction_v2i64(<2 x i64> %x, <2 x i64> %y) { ; SSE42-NEXT: setne %al ; SSE42-NEXT: retq ; -; AVX1OR2-LABEL: bool_reduction_v2i64: -; AVX1OR2: # %bb.0: -; AVX1OR2-NEXT: vpxor %xmm1, %xmm0, %xmm0 -; AVX1OR2-NEXT: vptest %xmm0, %xmm0 -; AVX1OR2-NEXT: setne %al -; AVX1OR2-NEXT: retq -; -; AVX512-LABEL: bool_reduction_v2i64: -; AVX512: # %bb.0: -; AVX512-NEXT: vpcmpneqq %xmm1, %xmm0, %k0 -; AVX512-NEXT: kmovd %k0, %eax -; AVX512-NEXT: testb %al, %al -; AVX512-NEXT: setne %al -; AVX512-NEXT: retq +; AVX-LABEL: bool_reduction_v2i64: +; AVX: # %bb.0: +; AVX-NEXT: vpxor %xmm1, %xmm0, %xmm0 +; AVX-NEXT: vptest %xmm0, %xmm0 +; AVX-NEXT: setne %al +; AVX-NEXT: retq %a = icmp ne <2 x i64> %x, %y %b = shufflevector <2 x i1> %a, <2 x i1> undef, <2 x i32> <i32 1, i32 undef> %c = or <2 x i1> %a, %b -- 2.7.4