From 7d318b2bb19771745021145730387d43c589a9a7 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Wed, 24 Jul 2019 17:29:56 +0000 Subject: [PATCH] [DAGCombine] matchBinOpReduction - add partial reduction matching This patch adds support for recognizing cases where a larger vector type is being used to reduce just the elements in the lower subvector: e.g. <8 x i32> reduction pattern in a <16 x i32> vector: <4,5,6,7,u,u,u,u,u,u,u,u,u,u,u,u> <2,3,u,u,u,u,u,u,u,u,u,u,u,u,u,u> <1,u,u,u,u,u,u,u,u,u,u,u,u,u,u,u> matchBinOpReduction returns the lower extracted subvector in such cases, assuming isExtractSubvectorCheap accepts the extraction. I've only enabled it for X86 reduction sums so far. I intend to enable it for the bitop/minmax cases in future patches, and eventually I think its worth turning it on all the time. This is mainly just a case of ensuring calls to matchBinOpReduction don't make assumptions on the vector width based on the original vector extraction. Fixes the x86 partial reduction sum cases in PR33758 and PR42023. Differential Revision: https://reviews.llvm.org/D65047 llvm-svn: 366933 --- llvm/include/llvm/CodeGen/SelectionDAG.h | 7 +- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 39 +++++-- llvm/lib/Target/X86/X86ISelLowering.cpp | 17 ++- llvm/test/CodeGen/X86/phaddsub-extract.ll | 143 +++++-------------------- 4 files changed, 69 insertions(+), 137 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 12a9708..a79d6dc 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1588,9 +1588,12 @@ public: /// Extract. The reduction must use one of the opcodes listed in /p /// CandidateBinOps and on success /p BinOp will contain the matching opcode. /// Returns the vector that is being reduced on, or SDValue() if a reduction - /// was not matched. + /// was not matched. If \p AllowPartials is set then in the case of a + /// reduction pattern that only matches the first few stages, the extracted + /// subvector of the start of the reduction is returned. SDValue matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, - ArrayRef CandidateBinOps); + ArrayRef CandidateBinOps, + bool AllowPartials = false); /// Utility function used by legalize and lowering to /// "unroll" a vector operation by splitting out the scalars and operating diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index ee9d13c..014e672 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -9005,7 +9005,8 @@ void SDNode::intersectFlagsWith(const SDNodeFlags Flags) { SDValue SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, - ArrayRef CandidateBinOps) { + ArrayRef CandidateBinOps, + bool AllowPartials) { // The pattern must end in an extract from index 0. if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT || !isNullConstant(Extract->getOperand(1))) @@ -9019,6 +9020,23 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, return Op.getOpcode() == unsigned(BinOp); })) return SDValue(); + unsigned CandidateBinOp = Op.getOpcode(); + + // Matching failed - attempt to see if we did enough stages that a partial + // reduction from a subvector is possible. + auto PartialReduction = [&](SDValue Op, unsigned NumSubElts) { + if (!AllowPartials || !Op) + return SDValue(); + EVT OpVT = Op.getValueType(); + EVT OpSVT = OpVT.getScalarType(); + EVT SubVT = EVT::getVectorVT(*getContext(), OpSVT, NumSubElts); + if (!TLI->isExtractSubvectorCheap(SubVT, OpVT, 0)) + return SDValue(); + BinOp = (ISD::NodeType)CandidateBinOp; + return getNode( + ISD::EXTRACT_SUBVECTOR, SDLoc(Op), SubVT, Op, + getConstant(0, SDLoc(Op), TLI->getVectorIdxTy(getDataLayout()))); + }; // At each stage, we're looking for something that looks like: // %s = shufflevector <8 x i32> %op, <8 x i32> undef, @@ -9030,10 +9048,15 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, // <4,5,6,7,u,u,u,u> // <2,3,u,u,u,u,u,u> // <1,u,u,u,u,u,u,u> - unsigned CandidateBinOp = Op.getOpcode(); + // While a partial reduction match would be: + // <2,3,u,u,u,u,u,u> + // <1,u,u,u,u,u,u,u> + SDValue PrevOp; for (unsigned i = 0; i < Stages; ++i) { + unsigned MaskEnd = (1 << i); + if (Op.getOpcode() != CandidateBinOp) - return SDValue(); + return PartialReduction(PrevOp, MaskEnd); SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); @@ -9049,12 +9072,14 @@ SelectionDAG::matchBinOpReduction(SDNode *Extract, ISD::NodeType &BinOp, // The first operand of the shuffle should be the same as the other operand // of the binop. if (!Shuffle || Shuffle->getOperand(0) != Op) - return SDValue(); + return PartialReduction(PrevOp, MaskEnd); // Verify the shuffle has the expected (at this stage of the pyramid) mask. - for (int Index = 0, MaskEnd = 1 << i; Index < MaskEnd; ++Index) - if (Shuffle->getMaskElt(Index) != MaskEnd + Index) - return SDValue(); + for (int Index = 0; Index < (int)MaskEnd; ++Index) + if (Shuffle->getMaskElt(Index) != (MaskEnd + Index)) + return PartialReduction(PrevOp, MaskEnd); + + PrevOp = Op; } BinOp = (ISD::NodeType)CandidateBinOp; diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 7d985f2..14fa8e1 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35688,7 +35688,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG, // TODO: Allow FADD with reduction and/or reassociation and no-signed-zeros. ISD::NodeType Opc; - SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD}); + SDValue Rdx = DAG.matchBinOpReduction(ExtElt, Opc, {ISD::ADD}, true); if (!Rdx) return SDValue(); @@ -35697,7 +35697,7 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG, "Reduction doesn't end in an extract from index 0"); EVT VT = ExtElt->getValueType(0); - EVT VecVT = ExtElt->getOperand(0).getValueType(); + EVT VecVT = Rdx.getValueType(); if (VecVT.getScalarType() != VT) return SDValue(); @@ -35711,14 +35711,14 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG, // vXi8 reduction - sum lo/hi halves then use PSADBW. if (VT == MVT::i8) { while (Rdx.getValueSizeInBits() > 128) { - EVT RdxVT = Rdx.getValueType(); - unsigned HalfSize = RdxVT.getSizeInBits() / 2; - unsigned HalfElts = RdxVT.getVectorNumElements() / 2; + unsigned HalfSize = VecVT.getSizeInBits() / 2; + unsigned HalfElts = VecVT.getVectorNumElements() / 2; SDValue Lo = extractSubVector(Rdx, 0, DAG, DL, HalfSize); SDValue Hi = extractSubVector(Rdx, HalfElts, DAG, DL, HalfSize); Rdx = DAG.getNode(ISD::ADD, DL, Lo.getValueType(), Lo, Hi); + VecVT = Rdx.getValueType(); } - assert(Rdx.getValueType() == MVT::v16i8 && "v16i8 reduction expected"); + assert(VecVT == MVT::v16i8 && "v16i8 reduction expected"); SDValue Hi = DAG.getVectorShuffle( MVT::v16i8, DL, Rdx, Rdx, @@ -35746,15 +35746,14 @@ static SDValue combineReductionToHorizontal(SDNode *ExtElt, SelectionDAG &DAG, unsigned NumElts = VecVT.getVectorNumElements(); SDValue Hi = extract128BitVector(Rdx, NumElts / 2, DAG, DL); SDValue Lo = extract128BitVector(Rdx, 0, DAG, DL); - VecVT = EVT::getVectorVT(*DAG.getContext(), VT, NumElts / 2); - Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Hi, Lo); + Rdx = DAG.getNode(HorizOpcode, DL, Lo.getValueType(), Hi, Lo); + VecVT = Rdx.getValueType(); } if (!((VecVT == MVT::v8i16 || VecVT == MVT::v4i32) && Subtarget.hasSSSE3()) && !((VecVT == MVT::v4f32 || VecVT == MVT::v2f64) && Subtarget.hasSSE3())) return SDValue(); // extract (add (shuf X), X), 0 --> extract (hadd X, X), 0 - assert(Rdx.getValueType() == VecVT && "Unexpected reduction match"); unsigned ReductionSteps = Log2_32(VecVT.getVectorNumElements()); for (unsigned i = 0; i != ReductionSteps; ++i) Rdx = DAG.getNode(HorizOpcode, DL, VecVT, Rdx, Rdx); diff --git a/llvm/test/CodeGen/X86/phaddsub-extract.ll b/llvm/test/CodeGen/X86/phaddsub-extract.ll index 2a7039e..b7af19b 100644 --- a/llvm/test/CodeGen/X86/phaddsub-extract.ll +++ b/llvm/test/CodeGen/X86/phaddsub-extract.ll @@ -1699,8 +1699,7 @@ define i32 @partial_reduction_add_v8i32(<8 x i32> %x) { ; ; AVX-FAST-LABEL: partial_reduction_add_v8i32: ; AVX-FAST: # %bb.0: -; AVX-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 ; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 ; AVX-FAST-NEXT: vmovd %xmm0, %eax ; AVX-FAST-NEXT: vzeroupper @@ -1741,34 +1740,13 @@ define i32 @partial_reduction_add_v16i32(<16 x i32> %x) { ; AVX-SLOW-NEXT: vzeroupper ; AVX-SLOW-NEXT: retq ; -; AVX1-FAST-LABEL: partial_reduction_add_v16i32: -; AVX1-FAST: # %bb.0: -; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 -; AVX1-FAST-NEXT: vmovd %xmm0, %eax -; AVX1-FAST-NEXT: vzeroupper -; AVX1-FAST-NEXT: retq -; -; AVX2-FAST-LABEL: partial_reduction_add_v16i32: -; AVX2-FAST: # %bb.0: -; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-FAST-NEXT: vmovd %xmm0, %eax -; AVX2-FAST-NEXT: vzeroupper -; AVX2-FAST-NEXT: retq -; -; AVX512-FAST-LABEL: partial_reduction_add_v16i32: -; AVX512-FAST: # %bb.0: -; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-FAST-NEXT: vmovd %xmm0, %eax -; AVX512-FAST-NEXT: vzeroupper -; AVX512-FAST-NEXT: retq +; AVX-FAST-LABEL: partial_reduction_add_v16i32: +; AVX-FAST: # %bb.0: +; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 +; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 +; AVX-FAST-NEXT: vmovd %xmm0, %eax +; AVX-FAST-NEXT: vzeroupper +; AVX-FAST-NEXT: retq %x23 = shufflevector <16 x i32> %x, <16 x i32> undef, <16 x i32> %x0213 = add <16 x i32> %x, %x23 %x13 = shufflevector <16 x i32> %x0213, <16 x i32> undef, <16 x i32> @@ -2010,8 +1988,7 @@ define i32 @hadd32_8(<8 x i32> %x225) { ; ; AVX-FAST-LABEL: hadd32_8: ; AVX-FAST: # %bb.0: -; AVX-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 ; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 ; AVX-FAST-NEXT: vmovd %xmm0, %eax ; AVX-FAST-NEXT: vzeroupper @@ -2052,34 +2029,13 @@ define i32 @hadd32_16(<16 x i32> %x225) { ; AVX-SLOW-NEXT: vzeroupper ; AVX-SLOW-NEXT: retq ; -; AVX1-FAST-LABEL: hadd32_16: -; AVX1-FAST: # %bb.0: -; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 -; AVX1-FAST-NEXT: vmovd %xmm0, %eax -; AVX1-FAST-NEXT: vzeroupper -; AVX1-FAST-NEXT: retq -; -; AVX2-FAST-LABEL: hadd32_16: -; AVX2-FAST: # %bb.0: -; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-FAST-NEXT: vmovd %xmm0, %eax -; AVX2-FAST-NEXT: vzeroupper -; AVX2-FAST-NEXT: retq -; -; AVX512-FAST-LABEL: hadd32_16: -; AVX512-FAST: # %bb.0: -; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-FAST-NEXT: vmovd %xmm0, %eax -; AVX512-FAST-NEXT: vzeroupper -; AVX512-FAST-NEXT: retq +; AVX-FAST-LABEL: hadd32_16: +; AVX-FAST: # %bb.0: +; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 +; AVX-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 +; AVX-FAST-NEXT: vmovd %xmm0, %eax +; AVX-FAST-NEXT: vzeroupper +; AVX-FAST-NEXT: retq %x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32> %x227 = add <16 x i32> %x225, %x226 %x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32> @@ -2149,8 +2105,7 @@ define i32 @hadd32_8_optsize(<8 x i32> %x225) optsize { ; ; AVX-LABEL: hadd32_8_optsize: ; AVX: # %bb.0: -; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0 +; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0 ; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0 ; AVX-NEXT: vmovd %xmm0, %eax ; AVX-NEXT: vzeroupper @@ -2172,63 +2127,13 @@ define i32 @hadd32_16_optsize(<16 x i32> %x225) optsize { ; SSE3-NEXT: movd %xmm1, %eax ; SSE3-NEXT: retq ; -; AVX1-SLOW-LABEL: hadd32_16_optsize: -; AVX1-SLOW: # %bb.0: -; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX1-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX1-SLOW-NEXT: vphaddd %xmm0, %xmm0, %xmm0 -; AVX1-SLOW-NEXT: vmovd %xmm0, %eax -; AVX1-SLOW-NEXT: vzeroupper -; AVX1-SLOW-NEXT: retq -; -; AVX1-FAST-LABEL: hadd32_16_optsize: -; AVX1-FAST: # %bb.0: -; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX1-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX1-FAST-NEXT: vphaddd %xmm0, %xmm0, %xmm0 -; AVX1-FAST-NEXT: vmovd %xmm0, %eax -; AVX1-FAST-NEXT: vzeroupper -; AVX1-FAST-NEXT: retq -; -; AVX2-SLOW-LABEL: hadd32_16_optsize: -; AVX2-SLOW: # %bb.0: -; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX2-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX2-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-SLOW-NEXT: vmovd %xmm0, %eax -; AVX2-SLOW-NEXT: vzeroupper -; AVX2-SLOW-NEXT: retq -; -; AVX2-FAST-LABEL: hadd32_16_optsize: -; AVX2-FAST: # %bb.0: -; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX2-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX2-FAST-NEXT: vmovd %xmm0, %eax -; AVX2-FAST-NEXT: vzeroupper -; AVX2-FAST-NEXT: retq -; -; AVX512-SLOW-LABEL: hadd32_16_optsize: -; AVX512-SLOW: # %bb.0: -; AVX512-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX512-SLOW-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-SLOW-NEXT: vmovd %xmm0, %eax -; AVX512-SLOW-NEXT: vzeroupper -; AVX512-SLOW-NEXT: retq -; -; AVX512-FAST-LABEL: hadd32_16_optsize: -; AVX512-FAST: # %bb.0: -; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] -; AVX512-FAST-NEXT: vpaddd %xmm1, %xmm0, %xmm0 -; AVX512-FAST-NEXT: vmovd %xmm0, %eax -; AVX512-FAST-NEXT: vzeroupper -; AVX512-FAST-NEXT: retq +; AVX-LABEL: hadd32_16_optsize: +; AVX: # %bb.0: +; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0 +; AVX-NEXT: vphaddd %xmm0, %xmm0, %xmm0 +; AVX-NEXT: vmovd %xmm0, %eax +; AVX-NEXT: vzeroupper +; AVX-NEXT: retq %x226 = shufflevector <16 x i32> %x225, <16 x i32> undef, <16 x i32> %x227 = add <16 x i32> %x225, %x226 %x228 = shufflevector <16 x i32> %x227, <16 x i32> undef, <16 x i32> -- 2.7.4