From 672b62ea21dfe5f9bfb2b0362785f2685be830a0 Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Tue, 21 Apr 2020 10:44:46 +0000 Subject: [PATCH] [AArch64][SVE] Custom lowering of floating-point reductions Summary: This patch implements custom floating-point reduction ISD nodes that have vector results, which are used to lower the following intrinsics: * llvm.aarch64.sve.fadda * llvm.aarch64.sve.faddv * llvm.aarch64.sve.fmaxv * llvm.aarch64.sve.fmaxnmv * llvm.aarch64.sve.fminv * llvm.aarch64.sve.fminnmv SVE reduction instructions keep their result within a vector register, with all other bits set to zero. Changes in this patch were implemented by Paul Walker and Sander de Smalen. Reviewers: sdesmalen, efriedma, rengolin Reviewed By: efriedma Differential Revision: https://reviews.llvm.org/D78723 --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 58 ++++++++++++++++++++++ llvm/lib/Target/AArch64/AArch64ISelLowering.h | 8 +++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 46 +++++++++++------ llvm/lib/Target/AArch64/SVEInstrFormats.td | 32 ++++++------ .../CodeGen/AArch64/sve-intrinsics-fp-reduce.ll | 2 +- 5 files changed, 113 insertions(+), 33 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 9d8151d..e5f3a4e 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1366,6 +1366,12 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { case AArch64ISD::REV: return "AArch64ISD::REV"; case AArch64ISD::REINTERPRET_CAST: return "AArch64ISD::REINTERPRET_CAST"; case AArch64ISD::TBL: return "AArch64ISD::TBL"; + case AArch64ISD::FADDA_PRED: return "AArch64ISD::FADDA_PRED"; + case AArch64ISD::FADDV_PRED: return "AArch64ISD::FADDV_PRED"; + case AArch64ISD::FMAXV_PRED: return "AArch64ISD::FMAXV_PRED"; + case AArch64ISD::FMAXNMV_PRED: return "AArch64ISD::FMAXNMV_PRED"; + case AArch64ISD::FMINV_PRED: return "AArch64ISD::FMINV_PRED"; + case AArch64ISD::FMINNMV_PRED: return "AArch64ISD::FMINNMV_PRED"; case AArch64ISD::NOT: return "AArch64ISD::NOT"; case AArch64ISD::BIT: return "AArch64ISD::BIT"; case AArch64ISD::CBZ: return "AArch64ISD::CBZ"; @@ -11308,6 +11314,46 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, return DAG.getZExtOrTrunc(Res, DL, VT); } +static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc, + SelectionDAG &DAG) { + SDLoc DL(N); + + SDValue Pred = N->getOperand(1); + SDValue VecToReduce = N->getOperand(2); + + EVT ReduceVT = VecToReduce.getValueType(); + SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce); + + // SVE reductions set the whole vector register with the first element + // containing the reduction result, which we'll now extract. + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce, + Zero); +} + +static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc, + SelectionDAG &DAG) { + SDLoc DL(N); + + SDValue Pred = N->getOperand(1); + SDValue InitVal = N->getOperand(2); + SDValue VecToReduce = N->getOperand(3); + EVT ReduceVT = VecToReduce.getValueType(); + + // Ordered reductions use the first lane of the result vector as the + // reduction's initial value. + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + InitVal = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ReduceVT, + DAG.getUNDEF(ReduceVT), InitVal, Zero); + + SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, InitVal, VecToReduce); + + // SVE reductions set the whole vector register with the first element + // containing the reduction result, which we'll now extract. + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce, + Zero); +} + static SDValue performIntrinsicCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -11391,6 +11437,18 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::aarch64_sve_udiv: return DAG.getNode(AArch64ISD::UDIV_PRED, SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2), N->getOperand(3)); + case Intrinsic::aarch64_sve_fadda: + return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG); + case Intrinsic::aarch64_sve_faddv: + return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG); + case Intrinsic::aarch64_sve_fmaxnmv: + return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG); + case Intrinsic::aarch64_sve_fmaxv: + return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG); + case Intrinsic::aarch64_sve_fminnmv: + return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG); + case Intrinsic::aarch64_sve_fminv: + return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG); case Intrinsic::aarch64_sve_sel: return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0), N->getOperand(1), N->getOperand(2), N->getOperand(3)); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index fe67d75..c5afa9b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -215,6 +215,14 @@ enum NodeType : unsigned { REV, TBL, + // Floating-point reductions. + FADDA_PRED, + FADDV_PRED, + FMAXV_PRED, + FMAXNMV_PRED, + FMINV_PRED, + FMINNMV_PRED, + INSR, PTEST, PTRUE, diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index ed79063..7c964da 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -134,16 +134,20 @@ def sve_cntw_imm_neg : ComplexPattern">; def sve_cntd_imm_neg : ComplexPattern">; def SDT_AArch64Reduce : SDTypeProfile<1, 2, [SDTCisVec<1>, SDTCisVec<2>]>; - -def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>; -def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>; -def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>; -def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>; -def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>; -def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>; -def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>; -def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>; -def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>; +def AArch64faddv_pred : SDNode<"AArch64ISD::FADDV_PRED", SDT_AArch64Reduce>; +def AArch64fmaxv_pred : SDNode<"AArch64ISD::FMAXV_PRED", SDT_AArch64Reduce>; +def AArch64fmaxnmv_pred : SDNode<"AArch64ISD::FMAXNMV_PRED", SDT_AArch64Reduce>; +def AArch64fminv_pred : SDNode<"AArch64ISD::FMINV_PRED", SDT_AArch64Reduce>; +def AArch64fminnmv_pred : SDNode<"AArch64ISD::FMINNMV_PRED", SDT_AArch64Reduce>; +def AArch64smaxv_pred : SDNode<"AArch64ISD::SMAXV_PRED", SDT_AArch64Reduce>; +def AArch64umaxv_pred : SDNode<"AArch64ISD::UMAXV_PRED", SDT_AArch64Reduce>; +def AArch64sminv_pred : SDNode<"AArch64ISD::SMINV_PRED", SDT_AArch64Reduce>; +def AArch64uminv_pred : SDNode<"AArch64ISD::UMINV_PRED", SDT_AArch64Reduce>; +def AArch64orv_pred : SDNode<"AArch64ISD::ORV_PRED", SDT_AArch64Reduce>; +def AArch64eorv_pred : SDNode<"AArch64ISD::EORV_PRED", SDT_AArch64Reduce>; +def AArch64andv_pred : SDNode<"AArch64ISD::ANDV_PRED", SDT_AArch64Reduce>; +def AArch64lasta : SDNode<"AArch64ISD::LASTA", SDT_AArch64Reduce>; +def AArch64lastb : SDNode<"AArch64ISD::LASTB", SDT_AArch64Reduce>; def SDT_AArch64DIV : SDTypeProfile<1, 3, [ SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>, @@ -156,6 +160,7 @@ def AArch64udiv_pred : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64DIV>; def SDT_AArch64ReduceWithInit : SDTypeProfile<1, 3, [SDTCisVec<1>, SDTCisVec<3>]>; def AArch64clasta_n : SDNode<"AArch64ISD::CLASTA_N", SDT_AArch64ReduceWithInit>; def AArch64clastb_n : SDNode<"AArch64ISD::CLASTB_N", SDT_AArch64ReduceWithInit>; +def AArch64fadda_pred : SDNode<"AArch64ISD::FADDA_PRED", SDT_AArch64ReduceWithInit>; def SDT_AArch64Rev : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisSameAs<0,1>]>; def AArch64rev : SDNode<"AArch64ISD::REV", SDT_AArch64Rev>; @@ -352,12 +357,21 @@ let Predicates = [HasSVE] in { defm FMUL_ZZZI : sve_fp_fmul_by_indexed_elem<"fmul", int_aarch64_sve_fmul_lane>; // SVE floating point reductions. - defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", int_aarch64_sve_fadda>; - defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", int_aarch64_sve_faddv>; - defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", int_aarch64_sve_fmaxnmv>; - defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", int_aarch64_sve_fminnmv>; - defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", int_aarch64_sve_fmaxv>; - defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", int_aarch64_sve_fminv>; + defm FADDA_VPZ : sve_fp_2op_p_vd<0b000, "fadda", AArch64fadda_pred>; + defm FADDV_VPZ : sve_fp_fast_red<0b000, "faddv", AArch64faddv_pred>; + defm FMAXNMV_VPZ : sve_fp_fast_red<0b100, "fmaxnmv", AArch64fmaxnmv_pred>; + defm FMINNMV_VPZ : sve_fp_fast_red<0b101, "fminnmv", AArch64fminnmv_pred>; + defm FMAXV_VPZ : sve_fp_fast_red<0b110, "fmaxv", AArch64fmaxv_pred>; + defm FMINV_VPZ : sve_fp_fast_red<0b111, "fminv", AArch64fminv_pred>; + + // Use more efficient NEON instructions to extract elements within the NEON + // part (first 128bits) of an SVE register. + def : Pat<(vector_extract (nxv8f16 ZPR:$Zs), (i64 0)), + (f16 (EXTRACT_SUBREG (v8f16 (EXTRACT_SUBREG ZPR:$Zs, zsub)), hsub))>; + def : Pat<(vector_extract (nxv4f32 ZPR:$Zs), (i64 0)), + (f32 (EXTRACT_SUBREG (v4f32 (EXTRACT_SUBREG ZPR:$Zs, zsub)), ssub))>; + def : Pat<(vector_extract (nxv2f64 ZPR:$Zs), (i64 0)), + (f64 (EXTRACT_SUBREG (v2f64 (EXTRACT_SUBREG ZPR:$Zs, zsub)), dsub))>; // Splat immediate (unpredicated) defm DUP_ZI : sve_int_dup_imm<"dup">; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 54fab60..a75bf47 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -4444,8 +4444,8 @@ multiclass sve2_int_while_rr rw, string asm, string op> { //===----------------------------------------------------------------------===// class sve_fp_fast_red sz, bits<3> opc, string asm, - ZPRRegOp zprty, RegisterClass dstRegClass> -: I<(outs dstRegClass:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn), + ZPRRegOp zprty, FPRasZPROperand dstOpType> +: I<(outs dstOpType:$Vd), (ins PPR3bAny:$Pg, zprty:$Zn), asm, "\t$Vd, $Pg, $Zn", "", []>, Sched<[]> { @@ -4463,13 +4463,13 @@ class sve_fp_fast_red sz, bits<3> opc, string asm, } multiclass sve_fp_fast_red opc, string asm, SDPatternOperator op> { - def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16>; - def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32>; - def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64>; + def _H : sve_fp_fast_red<0b01, opc, asm, ZPR16, FPR16asZPR>; + def _S : sve_fp_fast_red<0b10, opc, asm, ZPR32, FPR32asZPR>; + def _D : sve_fp_fast_red<0b11, opc, asm, ZPR64, FPR64asZPR>; - def : SVE_2_Op_Pat(NAME # _H)>; - def : SVE_2_Op_Pat(NAME # _S)>; - def : SVE_2_Op_Pat(NAME # _D)>; + def : SVE_2_Op_Pat(NAME # _H)>; + def : SVE_2_Op_Pat(NAME # _S)>; + def : SVE_2_Op_Pat(NAME # _D)>; } @@ -4478,8 +4478,8 @@ multiclass sve_fp_fast_red opc, string asm, SDPatternOperator op> { //===----------------------------------------------------------------------===// class sve_fp_2op_p_vd sz, bits<3> opc, string asm, - ZPRRegOp zprty, RegisterClass dstRegClass> -: I<(outs dstRegClass:$Vdn), (ins PPR3bAny:$Pg, dstRegClass:$_Vdn, zprty:$Zm), + ZPRRegOp zprty, FPRasZPROperand dstOpType> +: I<(outs dstOpType:$Vdn), (ins PPR3bAny:$Pg, dstOpType:$_Vdn, zprty:$Zm), asm, "\t$Vdn, $Pg, $_Vdn, $Zm", "", []>, @@ -4500,13 +4500,13 @@ class sve_fp_2op_p_vd sz, bits<3> opc, string asm, } multiclass sve_fp_2op_p_vd opc, string asm, SDPatternOperator op> { - def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16>; - def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32>; - def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64>; + def _H : sve_fp_2op_p_vd<0b01, opc, asm, ZPR16, FPR16asZPR>; + def _S : sve_fp_2op_p_vd<0b10, opc, asm, ZPR32, FPR32asZPR>; + def _D : sve_fp_2op_p_vd<0b11, opc, asm, ZPR64, FPR64asZPR>; - def : SVE_3_Op_Pat(NAME # _H)>; - def : SVE_3_Op_Pat(NAME # _S)>; - def : SVE_3_Op_Pat(NAME # _D)>; + def : SVE_3_Op_Pat(NAME # _H)>; + def : SVE_3_Op_Pat(NAME # _S)>; + def : SVE_3_Op_Pat(NAME # _D)>; } //===----------------------------------------------------------------------===// diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll index 083a7d3..c933c2ea 100644 --- a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll +++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-reduce.ll @@ -1,4 +1,4 @@ -; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s +; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -asm-verbose=0 < %s | FileCheck %s ; ; FADDA -- 2.7.4