From: Andrea Di Biagio
Date: Wed, 11 Jun 2014 07:57:50 +0000 (+0000)
Subject: [X86] Refactor the logic to select horizontal adds/subs to a helper function.
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c7af75f9a7701e48d69cf01d62efa46a215e9da0;p=platform%2Fupstream%2Fllvm.git

[X86] Refactor the logic to select horizontal adds/subs to a helper function.

This patch moves part of the logic implemented by the target-specific combine
rules added at r210477 to a separate helper function. This should make it
easier to add more rules for matching AVX/AVX2 horizontal adds/subs.

This patch also fixes a problem caused by a wrong check performed on the
indices of the extract_vector_elt dag nodes in input to the scalar adds/subs.

New tests have been added to verify that we correctly check the indices of
the extract_vector_elt dag nodes when selecting a horizontal operation.

llvm-svn: 210644
---

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c7c8cb5..8cf3b53 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6057,102 +6057,130 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(ISD::BITCAST, dl, VT, Select);
 }
 
-static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
-                                          const X86Subtarget *Subtarget) {
-  EVT VT = N->getValueType(0);
+/// \brief Return true if \p N implements a horizontal binop and return the
+/// operands for the horizontal binop into V0 and V1.
+///
+/// This is a helper function of PerformBUILD_VECTORCombine.
+/// This function checks that the build_vector \p N in input implements a
+/// horizontal operation. Parameter \p Opcode defines the kind of horizontal
+/// operation to match.
+/// For example, if \p Opcode is equal to ISD::ADD, then this function
+/// checks if \p N implements a horizontal arithmetic add; if instead \p Opcode
+/// is equal to ISD::SUB, then this function checks if this is a horizontal
+/// arithmetic sub.
+///
+/// This function only analyzes elements of \p N whose indices are
+/// in range [BaseIdx, LastIdx).
+static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode,
+                              unsigned BaseIdx, unsigned LastIdx,
+                              SDValue &V0, SDValue &V1) {
+  assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!");
+  assert(N->getValueType(0).isVector() &&
+         N->getValueType(0).getVectorNumElements() >= LastIdx &&
+         "Invalid Vector in input!");
+
+  bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD);
+  bool CanFold = true;
+  unsigned ExpectedVExtractIdx = BaseIdx;
+  unsigned NumElts = LastIdx - BaseIdx;
 
-  // Try to match a horizontal ADD or SUB.
-  if (((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) ||
-      ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) ||
-      ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
-        VT == MVT::v16i16) && Subtarget->hasAVX())) {
-    unsigned NumOperands = N->getNumOperands();
-    unsigned Opcode = N->getOperand(0)->getOpcode();
-    bool isCommutable = false;
-    bool CanFold = false;
-    switch (Opcode) {
-    default : break;
-    case ISD::ADD :
-    case ISD::FADD :
-      isCommutable = true;
-      // FALL-THROUGH
-    case ISD::SUB :
-    case ISD::FSUB :
-      CanFold = true;
-    }
-
-    // Verify that operands have the same opcode; also, the opcode can only
-    // be either of: ADD, FADD, SUB, FSUB.
-    SDValue InVec0, InVec1;
-    for (unsigned i = 0, e = NumOperands; i != e && CanFold; ++i) {
-      SDValue Op = N->getOperand(i);
-      CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
-
-      if (!CanFold)
-        break;
+  // Check if N implements a horizontal binop.
+  for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) {
+    SDValue Op = N->getOperand(i + BaseIdx);
+    CanFold = Op->getOpcode() == Opcode && Op->hasOneUse();
 
-      SDValue Op0 = Op.getOperand(0);
-      SDValue Op1 = Op.getOperand(1);
-
-      // Try to match the following pattern:
-      // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
-      CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-                 Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-                 Op0.getOperand(0) == Op1.getOperand(0) &&
-                 isa<ConstantSDNode>(Op0.getOperand(1)) &&
-                 isa<ConstantSDNode>(Op1.getOperand(1)));
-      if (!CanFold)
-        break;
+    if (!CanFold)
+      break;
 
-      unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
-      unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
-      unsigned ExpectedIndex = (i * 2) % NumOperands;
-
-      if (i == 0)
-        InVec0 = Op0.getOperand(0);
-      else if (i * 2 == NumOperands)
-        InVec1 = Op0.getOperand(0);
-
-      SDValue Expected = (i * 2 < NumOperands) ? InVec0 : InVec1;
-      if (I0 == ExpectedIndex)
-        CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
-      else if (isCommutable && I1 == ExpectedIndex) {
-        // Try to see if we can match the following dag sequence:
-        // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
-        CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
-      }
-    }
+    SDValue Op0 = Op.getOperand(0);
+    SDValue Op1 = Op.getOperand(1);
+
+    // Try to match the following pattern:
+    // (BINOP (extract_vector_elt A, I), (extract_vector_elt A, I+1))
+    CanFold = (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+               Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+               Op0.getOperand(0) == Op1.getOperand(0) &&
+               isa<ConstantSDNode>(Op0.getOperand(1)) &&
+               isa<ConstantSDNode>(Op1.getOperand(1)));
+    if (!CanFold)
+      break;
 
-    if (CanFold) {
-      unsigned NewOpcode;
-      switch (Opcode) {
-      default : llvm_unreachable("Unexpected opcode found!");
-      case ISD::ADD : NewOpcode = X86ISD::HADD; break;
-      case ISD::FADD : NewOpcode = X86ISD::FHADD; break;
-      case ISD::SUB : NewOpcode = X86ISD::HSUB; break;
-      case ISD::FSUB : NewOpcode = X86ISD::FHSUB; break;
-      }
+    unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue();
+    unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue();
 
-      if (VT.is256BitVector()) {
-        SDLoc dl(N);
-
-        // Convert this sequence into two horizontal add/sub followed
-        // by a concat vector.
-        SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, dl);
-        SDValue InVec0_HI =
-          Extract128BitVector(InVec0, NumOperands/2, DAG, dl);
-        SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, dl);
-        SDValue InVec1_HI =
-          Extract128BitVector(InVec1, NumOperands/2, DAG, dl);
-        EVT NewVT = InVec0_LO.getValueType();
-
-        SDValue LO = DAG.getNode(NewOpcode, dl, NewVT, InVec0_LO, InVec0_HI);
-        SDValue HI = DAG.getNode(NewOpcode, dl, NewVT, InVec1_LO, InVec1_HI);
-        return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, LO, HI);
-      }
+    if (i == 0)
+      V0 = Op0.getOperand(0);
+    else if (i * 2 == NumElts) {
+      V1 = Op0.getOperand(0);
+      ExpectedVExtractIdx = BaseIdx;
+    }
+
+    SDValue Expected = (i * 2 < NumElts) ? V0 : V1;
+    if (I0 == ExpectedVExtractIdx)
+      CanFold = I1 == I0 + 1 && Op0.getOperand(0) == Expected;
+    else if (IsCommutable && I1 == ExpectedVExtractIdx) {
+      // Try to match the following dag sequence:
+      // (BINOP (extract_vector_elt A, I+1), (extract_vector_elt A, I))
+      CanFold = I0 == I1 + 1 && Op1.getOperand(0) == Expected;
+    } else
+      CanFold = false;
 
-      return DAG.getNode(NewOpcode, SDLoc(N), VT, InVec0, InVec1);
-    }
+    ExpectedVExtractIdx += 2;
+  }
+
+  return CanFold;
+}
+
+static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
+                                          const X86Subtarget *Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  unsigned NumElts = VT.getVectorNumElements();
+  BuildVectorSDNode *BV = cast<BuildVectorSDNode>(N);
+  SDValue InVec0, InVec1;
+
+  // Try to match horizontal ADD/SUB.
+  if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) {
+    // Try to match an SSE3 float HADD/HSUB.
+    if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1);
+
+    if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1);
+  } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) {
+    // Try to match an SSSE3 integer HADD/HSUB.
+    if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1);
+
+    if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+      return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1);
+  }
+
+  if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 ||
+       VT == MVT::v16i16) && Subtarget->hasAVX()) {
+    unsigned X86Opcode;
+    if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::HADD;
+    else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::HSUB;
+    else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::FHADD;
+    else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1))
+      X86Opcode = X86ISD::FHSUB;
+    else
+      return SDValue();
+
+    // Convert this build_vector into two horizontal add/sub followed by
+    // a concat vector.
+    SDValue InVec0_LO = Extract128BitVector(InVec0, 0, DAG, DL);
+    SDValue InVec0_HI = Extract128BitVector(InVec0, NumElts/2, DAG, DL);
+    SDValue InVec1_LO = Extract128BitVector(InVec1, 0, DAG, DL);
+    SDValue InVec1_HI = Extract128BitVector(InVec1, NumElts/2, DAG, DL);
+    EVT NewVT = InVec0_LO.getValueType();
+
+    SDValue LO = DAG.getNode(X86Opcode, DL, NewVT, InVec0_LO, InVec0_HI);
+    SDValue HI = DAG.getNode(X86Opcode, DL, NewVT, InVec1_LO, InVec1_HI);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/X86/haddsub-2.ll b/llvm/test/CodeGen/X86/haddsub-2.ll
index 7b875c0..72217b3 100644
--- a/llvm/test/CodeGen/X86/haddsub-2.ll
+++ b/llvm/test/CodeGen/X86/haddsub-2.ll
@@ -86,12 +86,12 @@ define <4 x float> @hsub_ps_test2(<4 x float> %A, <4 x float> %B) {
   %vecext3 = extractelement <4 x float> %A, i32 1
   %sub4 = fsub float %vecext2, %vecext3
   %vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
-  %vecext6 = extractelement <4 x float> %B, i32 3
-  %vecext7 = extractelement <4 x float> %B, i32 2
+  %vecext6 = extractelement <4 x float> %B, i32 2
+  %vecext7 = extractelement <4 x float> %B, i32 3
   %sub8 = fsub float %vecext6, %vecext7
   %vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
-  %vecext10 = extractelement <4 x float> %B, i32 1
-  %vecext11 = extractelement <4 x float> %B, i32 0
+  %vecext10 = extractelement <4 x float> %B, i32 0
+  %vecext11 = extractelement <4 x float> %B, i32 1
   %sub12 = fsub float %vecext10, %vecext11
   %vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
   ret <4 x float> %vecinit13
@@ -137,12 +137,12 @@ define <4 x i32> @phadd_d_test2(<4 x i32> %A, <4 x i32> %B) {
   %vecext3 = extractelement <4 x i32> %A, i32 1
   %add4 = add i32 %vecext2, %vecext3
   %vecinit5 = insertelement <4 x i32> %vecinit, i32 %add4, i32 0
-  %vecext6 = extractelement <4 x i32> %B, i32 2
-  %vecext7 = extractelement <4 x i32> %B, i32 3
+  %vecext6 = extractelement <4 x i32> %B, i32 3
+  %vecext7 = extractelement <4 x i32> %B, i32 2
   %add8 = add i32 %vecext6, %vecext7
   %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %add8, i32 3
-  %vecext10 = extractelement <4 x i32> %B, i32 0
-  %vecext11 = extractelement <4 x i32> %B, i32 1
+  %vecext10 = extractelement <4 x i32> %B, i32 1
+  %vecext11 = extractelement <4 x i32> %B, i32 0
   %add12 = add i32 %vecext10, %vecext11
   %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %add12, i32 2
   ret <4 x i32> %vecinit13
@@ -191,12 +191,12 @@ define <4 x i32> @phsub_d_test2(<4 x i32> %A, <4 x i32> %B) {
   %vecext3 = extractelement <4 x i32> %A, i32 1
   %sub4 = sub i32 %vecext2, %vecext3
   %vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 0
-  %vecext6 = extractelement <4 x i32> %B, i32 3
-  %vecext7 = extractelement <4 x i32> %B, i32 2
+  %vecext6 = extractelement <4 x i32> %B, i32 2
+  %vecext7 = extractelement <4 x i32> %B, i32 3
   %sub8 = sub i32 %vecext6, %vecext7
   %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 3
-  %vecext10 = extractelement <4 x i32> %B, i32 1
-  %vecext11 = extractelement <4 x i32> %B, i32 0
+  %vecext10 = extractelement <4 x i32> %B, i32 0
+  %vecext11 = extractelement <4 x i32> %B, i32 1
   %sub12 = sub i32 %vecext10, %vecext11
   %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 2
   ret <4 x i32> %vecinit13
@@ -258,14 +258,14 @@ define <2 x double> @hsub_pd_test1(<2 x double> %A, <2 x double> %B) {
 
 define <2 x double> @hsub_pd_test2(<2 x double> %A, <2 x double> %B) {
-  %vecext = extractelement <2 x double> %A, i32 1
-  %vecext1 = extractelement <2 x double> %A, i32 0
+  %vecext = extractelement <2 x double> %B, i32 0
+  %vecext1 = extractelement <2 x double> %B, i32 1
   %sub = fsub double %vecext, %vecext1
-  %vecinit = insertelement <2 x double> undef, double %sub, i32 0
-  %vecext2 = extractelement <2 x double> %B, i32 1
-  %vecext3 = extractelement <2 x double> %B, i32 0
+  %vecinit = insertelement <2 x double> undef, double %sub, i32 1
+  %vecext2 = extractelement <2 x double> %A, i32 0
+  %vecext3 = extractelement <2 x double> %A, i32 1
   %sub2 = fsub double %vecext2, %vecext3
-  %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 1
+  %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
   ret <2 x double> %vecinit2
 }
 ; CHECK-LABEL: hsub_pd_test2
@@ -458,3 +458,68 @@ define <16 x i16> @avx2_vphadd_w_test(<16 x i16> %a, <16 x i16> %b) {
 ; CHECK: ret
 
+; Verify that we don't select horizontal subs in the following functions.
+
+define <4 x i32> @not_a_hsub_1(<4 x i32> %A, <4 x i32> %B) {
+  %vecext = extractelement <4 x i32> %A, i32 0
+  %vecext1 = extractelement <4 x i32> %A, i32 1
+  %sub = sub i32 %vecext, %vecext1
+  %vecinit = insertelement <4 x i32> undef, i32 %sub, i32 0
+  %vecext2 = extractelement <4 x i32> %A, i32 2
+  %vecext3 = extractelement <4 x i32> %A, i32 3
+  %sub4 = sub i32 %vecext2, %vecext3
+  %vecinit5 = insertelement <4 x i32> %vecinit, i32 %sub4, i32 1
+  %vecext6 = extractelement <4 x i32> %B, i32 1
+  %vecext7 = extractelement <4 x i32> %B, i32 0
+  %sub8 = sub i32 %vecext6, %vecext7
+  %vecinit9 = insertelement <4 x i32> %vecinit5, i32 %sub8, i32 2
+  %vecext10 = extractelement <4 x i32> %B, i32 3
+  %vecext11 = extractelement <4 x i32> %B, i32 2
+  %sub12 = sub i32 %vecext10, %vecext11
+  %vecinit13 = insertelement <4 x i32> %vecinit9, i32 %sub12, i32 3
+  ret <4 x i32> %vecinit13
+}
+; CHECK-LABEL: not_a_hsub_1
+; CHECK-NOT: phsubd
+; CHECK: ret
+
+
+define <4 x float> @not_a_hsub_2(<4 x float> %A, <4 x float> %B) {
+  %vecext = extractelement <4 x float> %A, i32 2
+  %vecext1 = extractelement <4 x float> %A, i32 3
+  %sub = fsub float %vecext, %vecext1
+  %vecinit = insertelement <4 x float> undef, float %sub, i32 1
+  %vecext2 = extractelement <4 x float> %A, i32 0
+  %vecext3 = extractelement <4 x float> %A, i32 1
+  %sub4 = fsub float %vecext2, %vecext3
+  %vecinit5 = insertelement <4 x float> %vecinit, float %sub4, i32 0
+  %vecext6 = extractelement <4 x float> %B, i32 3
+  %vecext7 = extractelement <4 x float> %B, i32 2
+  %sub8 = fsub float %vecext6, %vecext7
+  %vecinit9 = insertelement <4 x float> %vecinit5, float %sub8, i32 3
+  %vecext10 = extractelement <4 x float> %B, i32 0
+  %vecext11 = extractelement <4 x float> %B, i32 1
+  %sub12 = fsub float %vecext10, %vecext11
+  %vecinit13 = insertelement <4 x float> %vecinit9, float %sub12, i32 2
+  ret <4 x float> %vecinit13
+}
+; CHECK-LABEL: not_a_hsub_2
+; CHECK-NOT: hsubps
+; CHECK: ret
+
+
+define <2 x double> @not_a_hsub_3(<2 x double> %A, <2 x double> %B) {
+  %vecext = extractelement <2 x double> %B, i32 0
+  %vecext1 = extractelement <2 x double> %B, i32 1
+  %sub = fsub double %vecext, %vecext1
+  %vecinit = insertelement <2 x double> undef, double %sub, i32 1
+  %vecext2 = extractelement <2 x double> %A, i32 1
+  %vecext3 = extractelement <2 x double> %A, i32 0
+  %sub2 = fsub double %vecext2, %vecext3
+  %vecinit2 = insertelement <2 x double> %vecinit, double %sub2, i32 0
+  ret <2 x double> %vecinit2
+}
+; CHECK-LABEL: not_a_hsub_3
+; CHECK-NOT: hsubpd
+; CHECK: ret
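
Note: the example below is illustrative only and is not part of the committed
test file; the function name and CHECK lines are hypothetical. It sketches the
kind of commuted build_vector pattern that the IsCommutable path in
isHorizontalBinOp is intended to accept: each scalar add reads adjacent lanes
of the same source vector, but with the two extractelement operands swapped.
With SSSE3 enabled, a pattern like this is expected to fold to a single phaddd.

define <4 x i32> @phadd_d_commuted(<4 x i32> %A, <4 x i32> %B) {
  %a1 = extractelement <4 x i32> %A, i32 1
  %a0 = extractelement <4 x i32> %A, i32 0
  %add0 = add i32 %a1, %a0        ; A[1] + A[0], operands commuted
  %r0 = insertelement <4 x i32> undef, i32 %add0, i32 0
  %a3 = extractelement <4 x i32> %A, i32 3
  %a2 = extractelement <4 x i32> %A, i32 2
  %add1 = add i32 %a3, %a2        ; A[3] + A[2]
  %r1 = insertelement <4 x i32> %r0, i32 %add1, i32 1
  %b1 = extractelement <4 x i32> %B, i32 1
  %b0 = extractelement <4 x i32> %B, i32 0
  %add2 = add i32 %b1, %b0        ; B[1] + B[0]
  %r2 = insertelement <4 x i32> %r1, i32 %add2, i32 2
  %b3 = extractelement <4 x i32> %B, i32 3
  %b2 = extractelement <4 x i32> %B, i32 2
  %add3 = add i32 %b3, %b2        ; B[3] + B[2]
  %r3 = insertelement <4 x i32> %r2, i32 %add3, i32 3
  ret <4 x i32> %r3
}
; Hypothetical checks: a single horizontal add, no scalar adds.
; CHECK-LABEL: phadd_d_commuted
; CHECK: phaddd
; CHECK: ret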