From: Michael Kuperstein Date: Tue, 26 Jul 2016 20:01:29 +0000 (+0000) Subject: [X86] Split out absdiff detection from SAD combine. NFC. X-Git-Tag: llvmorg-4.0.0-rc1~14141 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2dc08f7df8b8edaf60df784996e2224786b3115c;p=platform%2Fupstream%2Fllvm.git [X86] Split out absdiff detection from SAD combine. NFC. Preparation for supporting PSADBW emission for straight-line code. llvm-svn: 276798 --- diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index c5e8102..67e56cc 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -30680,8 +30680,64 @@ static SDValue OptimizeConditionalInDecrement(SDNode *N, SelectionDAG &DAG) { DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp); } -static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, - const X86Subtarget &Subtarget) { +// Given a select, detect the following pattern: +// 1: %2 = zext <N x i8> %0 to <N x i32> +// 2: %3 = zext <N x i8> %1 to <N x i32> +// 3: %4 = sub nsw <N x i32> %2, %3 +// 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] +// 5: %6 = sub nsw <N x i32> zeroinitializer, %4 +// 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 +// This is useful as it is the input into a SAD pattern. +static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0, + SDValue &Op1) { + // Check the condition of the select instruction is greater-than. + SDValue SetCC = Select->getOperand(0); + if (SetCC.getOpcode() != ISD::SETCC) + return false; + ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); + if (CC != ISD::SETGT) + return false; + + SDValue SelectOp1 = Select->getOperand(1); + SDValue SelectOp2 = Select->getOperand(2); + + // The second operand of the select should be the negation of the first + // operand, which is implemented as 0 - SelectOp1. 
+ if (!(SelectOp2.getOpcode() == ISD::SUB && + ISD::isBuildVectorAllZeros(SelectOp2.getOperand(0).getNode()) && + SelectOp2.getOperand(1) == SelectOp1)) + return false; + + // The first operand of SetCC is the first operand of the select, which is the + // difference between the two input vectors. + if (SetCC.getOperand(0) != SelectOp1) + return false; + + // The second operand of the comparison can be either -1 or 0. + if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || + ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) + return false; + + // The first operand of the select is the difference between the two input + // vectors. + if (SelectOp1.getOpcode() != ISD::SUB) + return false; + + Op0 = SelectOp1.getOperand(0); + Op1 = SelectOp1.getOperand(1); + + // Check if the operands of the sub are zero-extended from vectors of i8. + if (Op0.getOpcode() != ISD::ZERO_EXTEND || + Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || + Op1.getOpcode() != ISD::ZERO_EXTEND || + Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8) + return false; + + return true; +} + +static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { SDLoc DL(N); EVT VT = N->getValueType(0); SDValue Op0 = N->getOperand(0); @@ -30701,21 +30757,8 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, if (VT.getSizeInBits() / 4 > RegSize) return SDValue(); - // Detect the following pattern: - // - // 1: %2 = zext <N x i8> %0 to <N x i32> - // 2: %3 = zext <N x i8> %1 to <N x i32> - // 3: %4 = sub nsw <N x i32> %2, %3 - // 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N] - // 5: %6 = sub nsw <N x i32> zeroinitializer, %4 - // 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6 - // 7: %8 = add nsw <N x i32> %7, %vec.phi - // - // The last instruction must be a reduction add. The instructions 3-6 forms an - // ABSDIFF pattern. - - // The two operands of reduction add are from PHI and a select-op as in line 7 - // above. + // We know N is a reduction add, which means one of its operands is a phi. 
+ // To match SAD, we need the other operand to be a vector select. SDValue SelectOp, Phi; if (Op0.getOpcode() == ISD::VSELECT) { SelectOp = Op0; @@ -30726,50 +30769,12 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG, } else return SDValue(); - // Check the condition of the select instruction is greater-than. - SDValue SetCC = SelectOp->getOperand(0); - if (SetCC.getOpcode() != ISD::SETCC) - return SDValue(); - ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get(); - if (CC != ISD::SETGT) - return SDValue(); - - Op0 = SelectOp->getOperand(1); - Op1 = SelectOp->getOperand(2); - - // The second operand of SelectOp Op1 is the negation of the first operand - // Op0, which is implemented as 0 - Op0. - if (!(Op1.getOpcode() == ISD::SUB && - ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) && - Op1.getOperand(1) == Op0)) - return SDValue(); - - // The first operand of SetCC is the first operand of SelectOp, which is the - // difference between two input vectors. - if (SetCC.getOperand(0) != Op0) - return SDValue(); - - // The second operand of > comparison can be either -1 or 0. - if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) || - ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode()))) - return SDValue(); - - // The first operand of SelectOp is the difference between two input vectors. - if (Op0.getOpcode() != ISD::SUB) - return SDValue(); - - Op1 = Op0.getOperand(1); - Op0 = Op0.getOperand(0); - - // Check if the operands of the diff are zero-extended from vectors of i8. - if (Op0.getOpcode() != ISD::ZERO_EXTEND || - Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 || - Op1.getOpcode() != ISD::ZERO_EXTEND || - Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8) + // Check whether we have an abs-diff pattern feeding into the select. + if(!detectZextAbsDiff(SelectOp, Op0, Op1)) return SDValue(); // SAD pattern detected. Now build a SAD instruction and an addition for - // reduction. 
Note that the number of elments of the result of SAD is less + // reduction. Note that the number of elements of the result of SAD is less // than the number of elements of its input. Therefore, we could only update // part of elements in the reduction vector. @@ -30819,7 +30824,7 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags; if (Flags->hasVectorReduction()) { - if (SDValue Sad = detectSADPattern(N, DAG, Subtarget)) + if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget)) return Sad; } EVT VT = N->getValueType(0);