From 2887f1463930044a6093f111dc8eba5594144c33 Mon Sep 17 00:00:00 2001 From: David Green Date: Sat, 26 Jun 2021 19:34:16 +0100 Subject: [PATCH] [ISel] Port AArch64 SABD and UABD to DAGCombine This ports the AArch64 SABD and USBD over to DAG Combine, where they can be used by more backends (notably MVE in a follow-up patch). The matching code has changed very little, just to handle legal operations and types differently. It selects from (ABS (SUB (EXTEND a), (EXTEND b))), producing a ubds/abdu which is zexted to the original type. Differential Revision: https://reviews.llvm.org/D91937 --- llvm/include/llvm/CodeGen/ISDOpcodes.h | 7 +++ llvm/include/llvm/Target/TargetSelectionDAG.td | 2 + llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 38 ++++++++++++ .../CodeGen/SelectionDAG/SelectionDAGDumper.cpp | 2 + llvm/lib/CodeGen/TargetLoweringBase.cpp | 4 ++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 68 +++++----------------- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 4 -- llvm/lib/Target/AArch64/AArch64InstrInfo.td | 7 +-- 8 files changed, 69 insertions(+), 63 deletions(-) diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h index adad8c1..6eb70ab 100644 --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -611,6 +611,13 @@ enum NodeType { MULHU, MULHS, + // ABDS/ABDU - Absolute difference - Return the absolute difference between + // two numbers interpreted as signed/unsigned. + // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1) + // or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1) + ABDS, + ABDU, + /// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned /// integers. SMIN, diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td index 1913396..c7f22bf 100644 --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -369,6 +369,8 @@ def mul : SDNode<"ISD::MUL" , SDTIntBinOp, [SDNPCommutative, SDNPAssociative]>; def mulhs : SDNode<"ISD::MULHS" , SDTIntBinOp, [SDNPCommutative]>; def mulhu : SDNode<"ISD::MULHU" , SDTIntBinOp, [SDNPCommutative]>; +def abds : SDNode<"ISD::ABDS" , SDTIntBinOp, [SDNPCommutative]>; +def abdu : SDNode<"ISD::ABDU" , SDTIntBinOp, [SDNPCommutative]>; def smullohi : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>; def umullohi : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>; def sdiv : SDNode<"ISD::SDIV" , SDTIntBinOp>; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 63c979c..5ea3de9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) { return SDValue(); } +// Given a ABS node, detect the following pattern: +// (ABS (SUB (EXTEND a), (EXTEND b))). +// Generates UABD/SABD instruction. +static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG, + const TargetLowering &TLI) { + SDValue AbsOp1 = N->getOperand(0); + SDValue Op0, Op1; + + if (AbsOp1.getOpcode() != ISD::SUB) + return SDValue(); + + Op0 = AbsOp1.getOperand(0); + Op1 = AbsOp1.getOperand(1); + + unsigned Opc0 = Op0.getOpcode(); + // Check if the operands of the sub are (zero|sign)-extended. + if (Opc0 != Op1.getOpcode() || + (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) + return SDValue(); + + EVT VT1 = Op0.getOperand(0).getValueType(); + EVT VT2 = Op1.getOperand(0).getValueType(); + // Check if the operands are of same type and valid size. + unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU; + if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) + return SDValue(); + + Op0 = Op0.getOperand(0); + Op1 = Op1.getOperand(0); + SDValue ABD = + DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD); +} + SDValue DAGCombiner::visitABS(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) { // fold (abs x) -> x iff not-negative if (DAG.SignBitIsZero(N0)) return N0; + + if (SDValue ABD = combineABSToABD(N, DAG, TLI)) + return ABD; + return SDValue(); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index 73c207e..40083c6 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const { case ISD::MUL: return "mul"; case ISD::MULHU: return "mulhu"; case ISD::MULHS: return "mulhs"; + case ISD::ABDS: return "abds"; + case ISD::ABDU: return "abdu"; case ISD::SDIV: return "sdiv"; case ISD::UDIV: return "udiv"; case ISD::SREM: return "srem"; diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index d2c291f..ebac779 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() { setOperationAction(ISD::SUBC, VT, Expand); setOperationAction(ISD::SUBE, VT, Expand); + // Absolute difference + setOperationAction(ISD::ABDS, VT, Expand); + setOperationAction(ISD::ABDU, VT, Expand); + // These default to Expand so they will be expanded to CTLZ/CTTZ by default. setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand); setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index fd5c9e0..9886d63 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::USUBSAT, VT, Legal); } + for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16, + MVT::v4i32}) { + setOperationAction(ISD::ABDS, VT, Legal); + setOperationAction(ISD::ABDU, VT, Legal); + } + // Vector reductions for (MVT VT : { MVT::v4f16, MVT::v2f32, MVT::v8f16, MVT::v4f32, MVT::v2f64 }) { @@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::INDEX_VECTOR) - MAKE_CASE(AArch64ISD::UABD) - MAKE_CASE(AArch64ISD::SABD) MAKE_CASE(AArch64ISD::UADDLP) MAKE_CASE(AArch64ISD::CALL_RVMARKER) } @@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, } case Intrinsic::aarch64_neon_sabd: case Intrinsic::aarch64_neon_uabd: { - unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD - : AArch64ISD::SABD; + unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU + : ISD::ABDS; return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); } @@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N, SDValue UABDHigh8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0), DAG.getConstant(8, DL, MVT::i64)); - SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD, - DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1); + SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8, + UABDHigh8Op0, UABDHigh8Op1); SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8); // Second, create the node pattern of UABAL. @@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N, SDValue UABDLo8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0), DAG.getConstant(0, DL, MVT::i64)); - SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD, - DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1); + SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8, + UABDLo8Op0, UABDLo8Op1); SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8); SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD); @@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); } -// Given a ABS node, detect the following pattern: -// (ABS (SUB (EXTEND a), (EXTEND b))). -// Generates UABD/SABD instruction. -static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const AArch64Subtarget *Subtarget) { - SDValue AbsOp1 = N->getOperand(0); - SDValue Op0, Op1; - - if (AbsOp1.getOpcode() != ISD::SUB) - return SDValue(); - - Op0 = AbsOp1.getOperand(0); - Op1 = AbsOp1.getOperand(1); - - unsigned Opc0 = Op0.getOpcode(); - // Check if the operands of the sub are (zero|sign)-extended. - if (Opc0 != Op1.getOpcode() || - (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) - return SDValue(); - - EVT VectorT1 = Op0.getOperand(0).getValueType(); - EVT VectorT2 = Op1.getOperand(0).getValueType(); - // Check if vectors are of same type and valid size. - uint64_t Size = VectorT1.getFixedSizeInBits(); - if (VectorT1 != VectorT2 || (Size != 64 && Size != 128)) - return SDValue(); - - // Check if vector element types are valid. - EVT VT1 = VectorT1.getVectorElementType(); - if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32) - return SDValue(); - - Op0 = Op0.getOperand(0); - Op1 = Op1.getOperand(0); - unsigned ABDOpcode = - (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD; - SDValue ABD = - DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD); -} - static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N, // helps the backend to decide that an sabdl2 would be useful, saving a real // extract_high operation. if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND && - (N->getOperand(0).getOpcode() == AArch64ISD::UABD || - N->getOperand(0).getOpcode() == AArch64ISD::SABD)) { + (N->getOperand(0).getOpcode() == ISD::ABDU || + N->getOperand(0).getOpcode() == ISD::ABDS)) { SDNode *ABDNode = N->getOperand(0).getNode(); SDValue NewABD = tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG); @@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, default: LLVM_DEBUG(dbgs() << "Custom combining: skipping\n"); break; - case ISD::ABS: - return performABSCombine(N, DAG, DCI, Subtarget); case ISD::ADD: case ISD::SUB: return performAddSubCombine(N, DCI, DAG); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index ced2607..7daa619 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -236,10 +236,6 @@ enum NodeType : unsigned { SRHADD, URHADD, - // Absolute difference - UABD, - SABD, - // Unsigned Add Long Pairwise UADDLP, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 301f1ed6..7802144 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -579,14 +579,11 @@ def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>; def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>; def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>; -def AArch64uabd_n : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>; -def AArch64sabd_n : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>; - def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs), - [(AArch64uabd_n node:$lhs, node:$rhs), + [(abdu node:$lhs, node:$rhs), (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>; def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs), - [(AArch64sabd_n node:$lhs, node:$rhs), + [(abds node:$lhs, node:$rhs), (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>; def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>; -- 2.7.4