From 08ce52ef5e6b879216f8018b920ef5c0621e797d Mon Sep 17 00:00:00 2001
From: Jingu Kang
Date: Thu, 10 Jun 2021 16:02:57 +0100
Subject: [PATCH] [AArch64] Improve SAD pattern

Given a vecreduce_add node, detect the below pattern and convert it to
the node sequence with UABDL, [S|U]ABD and UADDLP.

i32 vecreduce_add(
 v16i32 abs(
  v16i32 sub(
   v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
=================>
i32 vecreduce_add(
 v4i32 UADDLP(
  v8i16 add(
   v8i16 zext(
    v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b),
   v8i16 zext(
    v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b))))

Differential Revision: https://reviews.llvm.org/D104042
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 100 +++++++++++++++++++++++-
 llvm/lib/Target/AArch64/AArch64ISelLowering.h   |   3 +
 llvm/lib/Target/AArch64/AArch64InstrInfo.td     |  12 ++-
 llvm/test/CodeGen/AArch64/arm64-vabs.ll         |  18 ++---
 llvm/test/CodeGen/AArch64/neon-sad.ll           |  18 ++---
 5 files changed, 123 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0f89bff..9b0735a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2110,6 +2110,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::INDEX_VECTOR)
     MAKE_CASE(AArch64ISD::UABD)
     MAKE_CASE(AArch64ISD::SABD)
+    MAKE_CASE(AArch64ISD::UADDLP)
     MAKE_CASE(AArch64ISD::CALL_RVMARKER)
   }
 #undef MAKE_CASE
@@ -4078,6 +4079,10 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2));
   }
+  case Intrinsic::aarch64_neon_uaddlp: {
+    unsigned Opcode = AArch64ISD::UADDLP;
+    return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1));
+  }
   case Intrinsic::aarch64_neon_sdot:
   case Intrinsic::aarch64_neon_udot:
   case Intrinsic::aarch64_sve_sdot:
@@ -11981,13 +11986,106 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
 }
 
+// Given a vecreduce_add node, detect the below pattern and convert it to the
+// node sequence with UABDL, [S|U]ABD and UADDLP.
+//
+// i32 vecreduce_add(
+//  v16i32 abs(
+//   v16i32 sub(
+//    v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
+// =================>
+// i32 vecreduce_add(
+//  v4i32 UADDLP(
+//   v8i16 add(
+//    v8i16 zext(
+//     v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b),
+//    v8i16 zext(
+//     v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b))))
+static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
+                                                    SelectionDAG &DAG) {
+  // Assumed i32 vecreduce_add
+  if (N->getValueType(0) != MVT::i32)
+    return SDValue();
+
+  SDValue VecReduceOp0 = N->getOperand(0);
+  unsigned Opcode = VecReduceOp0.getOpcode();
+  // Assumed v16i32 abs
+  if (Opcode != ISD::ABS || VecReduceOp0->getValueType(0) != MVT::v16i32)
+    return SDValue();
+
+  SDValue ABS = VecReduceOp0;
+  // Assumed v16i32 sub
+  if (ABS->getOperand(0)->getOpcode() != ISD::SUB ||
+      ABS->getOperand(0)->getValueType(0) != MVT::v16i32)
+    return SDValue();
+
+  SDValue SUB = ABS->getOperand(0);
+  unsigned Opcode0 = SUB->getOperand(0).getOpcode();
+  unsigned Opcode1 = SUB->getOperand(1).getOpcode();
+  // Assumed v16i32 type
+  if (SUB->getOperand(0)->getValueType(0) != MVT::v16i32 ||
+      SUB->getOperand(1)->getValueType(0) != MVT::v16i32)
+    return SDValue();
+
+  // Assumed zext or sext
+  bool IsZExt = false;
+  if (Opcode0 == ISD::ZERO_EXTEND && Opcode1 == ISD::ZERO_EXTEND) {
+    IsZExt = true;
+  } else if (Opcode0 == ISD::SIGN_EXTEND && Opcode1 == ISD::SIGN_EXTEND) {
+    IsZExt = false;
+  } else
+    return SDValue();
+
+  SDValue EXT0 = SUB->getOperand(0);
+  SDValue EXT1 = SUB->getOperand(1);
+  // Assumed zext's operand has v16i8 type
+  if (EXT0->getOperand(0)->getValueType(0) != MVT::v16i8 ||
+      EXT1->getOperand(0)->getValueType(0) != MVT::v16i8)
+    return SDValue();
+
+  // Pattern is detected. Let's convert it to a sequence of nodes.
+  SDLoc DL(N);
+
+  // First, create the node pattern of UABD/SABD.
+  SDValue UABDHigh8Op0 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
+                  DAG.getConstant(8, DL, MVT::i64));
+  SDValue UABDHigh8Op1 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
+                  DAG.getConstant(8, DL, MVT::i64));
+  SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
+                                  DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
+  SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
+
+  // Second, create the node pattern of UABAL.
+  SDValue UABDLo8Op0 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue UABDLo8Op1 =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
+                  DAG.getConstant(0, DL, MVT::i64));
+  SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
+                                DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
+  SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
+  SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
+
+  // Third, create the node of UADDLP.
+  SDValue UADDLP = DAG.getNode(AArch64ISD::UADDLP, DL, MVT::v4i32, UABAL);
+
+  // Fourth, create the node of VECREDUCE_ADD.
+  return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
+}
+
 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
 //   vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
 //   vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
 static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                                           const AArch64Subtarget *ST) {
+  if (!ST->hasDotProd())
+    return performVecReduceAddCombineWithUADDLP(N, DAG);
+
   SDValue Op0 = N->getOperand(0);
-  if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32 ||
+  if (N->getValueType(0) != MVT::i32 ||
       Op0.getValueType().getVectorElementType() != MVT::i32)
     return SDValue();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c2ada6f..20872b4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -240,6 +240,9 @@ enum NodeType : unsigned {
   UABD,
   SABD,
 
+  // Unsigned Add Long Pairwise
+  UADDLP,
+
   // udot/sdot instructions
   UDOT,
   SDOT,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 33bd0be..c303d87 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -271,6 +271,8 @@ def SDT_AArch64ITOF : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisSameAs<0,1>]>;
 def SDT_AArch64TLSDescCall : SDTypeProfile<0, -2, [SDTCisPtrTy<0>,
                                                    SDTCisPtrTy<1>]>;
 
+def SDT_AArch64uaddlp : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
+
 def SDT_AArch64ldp : SDTypeProfile<2, 1, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>;
 def SDT_AArch64stp : SDTypeProfile<0, 3, [SDTCisVT<0, i64>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>;
 def SDT_AArch64stnp : SDTypeProfile<0, 3, [SDTCisVT<0, v4i32>, SDTCisSameAs<0, 1>, SDTCisPtrTy<2>]>;
@@ -587,6 +589,11 @@ def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs),
                            [(AArch64sabd_n node:$lhs, node:$rhs),
                             (int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
 
+def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;
+def AArch64uaddlp : PatFrags<(ops node:$src),
+                             [(AArch64uaddlp_n node:$src),
+                              (int_aarch64_neon_uaddlp node:$src)]>;
+
 def SDT_AArch64SETTAG : SDTypeProfile<0, 2, [SDTCisPtrTy<0>, SDTCisPtrTy<1>]>;
 def AArch64stg : SDNode<"AArch64ISD::STG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 def AArch64stzg : SDNode<"AArch64ISD::STZG", SDT_AArch64SETTAG, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
@@ -4178,9 +4185,8 @@ defm SQXTN : SIMDMixedTwoVector<0, 0b10100, "sqxtn", int_aarch64_neon_sqxtn>;
 defm SQXTUN : SIMDMixedTwoVector<1, 0b10010, "sqxtun", int_aarch64_neon_sqxtun>;
 defm SUQADD : SIMDTwoVectorBHSDTied<0, 0b00011, "suqadd",int_aarch64_neon_suqadd>;
 defm UADALP : SIMDLongTwoVectorTied<1, 0b00110, "uadalp",
-       BinOpFrag<(add node:$LHS, (int_aarch64_neon_uaddlp node:$RHS))> >;
-defm UADDLP : SIMDLongTwoVector<1, 0b00010, "uaddlp",
-                                int_aarch64_neon_uaddlp>;
+       BinOpFrag<(add node:$LHS, (AArch64uaddlp node:$RHS))> >;
+defm UADDLP : SIMDLongTwoVector<1, 0b00010, "uaddlp", AArch64uaddlp>;
 defm UCVTF  : SIMDTwoVectorIntToFP<1, 0, 0b11101, "ucvtf", uint_to_fp>;
 defm UQXTN  : SIMDMixedTwoVector<1, 0b10100, "uqxtn", int_aarch64_neon_uqxtn>;
 defm URECPE : SIMDTwoVectorS<0, 1, 0b11100, "urecpe", int_aarch64_neon_urecpe>;
diff --git a/llvm/test/CodeGen/AArch64/arm64-vabs.ll b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
index a5945bb..4d792bc 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vabs.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vabs.ll
@@ -218,12 +218,9 @@ define i16 @uabd16b_rdx(<16 x i8>* %a, <16 x i8>* %b) {
 define i32 @uabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
 ; CHECK-LABEL: uabd16b_rdx_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    uabd.16b v0, v0, v1
-; CHECK-NEXT:    ushll2.8h v1, v0, #0
-; CHECK-NEXT:    ushll.8h v0, v0, #0
-; CHECK-NEXT:    uaddl2.4s v2, v0, v1
-; CHECK-NEXT:    uaddl.4s v0, v0, v1
-; CHECK-NEXT:    add.4s v0, v0, v2
+; CHECK-NEXT:    uabdl.8h v2, v0, v1
+; CHECK-NEXT:    uabal2.8h v2, v0, v1
+; CHECK-NEXT:    uaddlp.4s v0, v2
 ; CHECK-NEXT:    addv.4s s0, v0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -240,12 +237,9 @@ define i32 @uabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
 define i32 @sabd16b_rdx_i32(<16 x i8> %a, <16 x i8> %b) {
 ; CHECK-LABEL: sabd16b_rdx_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sabd.16b v0, v0, v1
-; CHECK-NEXT:    ushll2.8h v1, v0, #0
-; CHECK-NEXT:    ushll.8h v0, v0, #0
-; CHECK-NEXT:    uaddl2.4s v2, v0, v1
-; CHECK-NEXT:    uaddl.4s v0, v0, v1
-; CHECK-NEXT:    add.4s v0, v0, v2
+; CHECK-NEXT:    sabdl.8h v2, v0, v1
+; CHECK-NEXT:    sabal2.8h v2, v0, v1
+; CHECK-NEXT:    uaddlp.4s v0, v2
 ; CHECK-NEXT:    addv.4s s0, v0
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/neon-sad.ll b/llvm/test/CodeGen/AArch64/neon-sad.ll
index c5372a2..cfd9712 100644
--- a/llvm/test/CodeGen/AArch64/neon-sad.ll
+++ b/llvm/test/CodeGen/AArch64/neon-sad.ll
@@ -9,12 +9,9 @@ define i32 @test_sad_v16i8_zext(i8* nocapture readonly %a, i8* nocapture readonl
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
-; CHECK-NEXT:    uabd v0.16b, v1.16b, v0.16b
-; CHECK-NEXT:    ushll2 v1.8h, v0.16b, #0
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    uabdl v2.8h, v1.8b, v0.8b
+; CHECK-NEXT:    uabal2 v2.8h, v1.16b, v0.16b
+; CHECK-NEXT:    uaddlp v0.4s, v2.8h
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
@@ -36,12 +33,9 @@ define i32 @test_sad_v16i8_sext(i8* nocapture readonly %a, i8* nocapture readonl
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr q0, [x0]
 ; CHECK-NEXT:    ldr q1, [x1]
-; CHECK-NEXT:    sabd v0.16b, v1.16b, v0.16b
-; CHECK-NEXT:    ushll2 v1.8h, v0.16b, #0
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    add v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    sabdl v2.8h, v1.8b, v0.8b
+; CHECK-NEXT:    sabal2 v2.8h, v1.16b, v0.16b
+; CHECK-NEXT:    uaddlp v0.4s, v2.8h
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
-- 
2.7.4
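For readers checking the transform by hand, here is a scalar C++ model of the
new node sequence: it mirrors UABDL/UABAL2 (widening absolute differences of
the low and high 8-byte halves of the v16i8 inputs), UADDLP (widening pairwise
add of the v8i16 accumulator) and the final ADDV reduction, and asserts that
the result equals the original i32 vecreduce_add(abs(sub(zext a, zext b)))
computation. This is a minimal sketch of the instruction semantics for the
unsigned (zext/UABD) path; the sext/SABD path is analogous with signed
widening. The helper name sad_via_uaddlp and the test values are invented for
illustration and are not LLVM or ACLE identifiers.

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Scalar model of the rewritten node sequence (hypothetical helper, not
// part of the patch).
static uint32_t sad_via_uaddlp(const std::array<uint8_t, 16> &A,
                               const std::array<uint8_t, 16> &B) {
  std::array<uint16_t, 8> Acc{};
  for (int I = 0; I < 8; ++I) {
    // UABDL: widening absolute difference of the low 8-byte halves.
    Acc[I] = uint16_t(std::abs(int(A[I]) - int(B[I])));
    // UABAL2: accumulate the widening absolute difference of the high halves.
    Acc[I] += uint16_t(std::abs(int(A[I + 8]) - int(B[I + 8])));
  }
  // UADDLP: pairwise add of adjacent v8i16 lanes into v4i32.
  std::array<uint32_t, 4> Pair{};
  for (int I = 0; I < 4; ++I)
    Pair[I] = uint32_t(Acc[2 * I]) + uint32_t(Acc[2 * I + 1]);
  // ADDV: horizontal reduction of v4i32 to a scalar.
  uint32_t Sum = 0;
  for (uint32_t P : Pair)
    Sum += P;
  return Sum;
}

int main() {
  std::array<uint8_t, 16> A, B;
  for (int I = 0; I < 16; ++I) {
    A[I] = uint8_t(3 * I + 1);
    B[I] = uint8_t(250 - 7 * I);
  }
  // Reference: i32 vecreduce_add(v16i32 abs(v16i32 sub(zext a, zext b))).
  uint32_t Ref = 0;
  for (int I = 0; I < 16; ++I)
    Ref += uint32_t(std::abs(int(A[I]) - int(B[I])));
  assert(sad_via_uaddlp(A, B) == Ref);
  return 0;
}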