From 463f50b436a2ac3000a90d273f2ed05893e8864f Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Tue, 30 May 2023 14:38:16 -0700 Subject: [PATCH] [RISCV] Add RISCVISD::VFWMUL_VL. Use it to replace isel patterns with a DAG combine. This is more consistent with how we handle integer widening multiply. A follow up patch will add support for matching vfwmul when the multiplicand is being squared. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 35 ++++++++++++++++++++++ llvm/lib/Target/RISCV/RISCVISelLowering.h | 2 ++ llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td | 23 +++++++++++++- 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 3dc04d0..9d02679 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -11355,6 +11355,38 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG) { N->getOperand(2), Mask, VL); } +static SDValue performVFMUL_VLCombine(SDNode *N, SelectionDAG &DAG) { + // FIXME: Ignore strict opcodes for now. + assert(!N->isTargetStrictFPOpcode() && "Unexpected opcode"); + + // Try to form widening multiply. + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + SDValue Merge = N->getOperand(2); + SDValue Mask = N->getOperand(3); + SDValue VL = N->getOperand(4); + + if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL || + Op1.getOpcode() != RISCVISD::FP_EXTEND_VL) + return SDValue(); + + // TODO: Refactor to handle more complex cases similar to + // combineBinOp_VLToVWBinOp_VL. + if (!Op0.hasOneUse() || !Op1.hasOneUse()) + return SDValue(); + + // Check the mask and VL are the same. + if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL || + Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL) + return SDValue(); + + Op0 = Op0.getOperand(0); + Op1 = Op1.getOperand(0); + + return DAG.getNode(RISCVISD::VFWMUL_VL, SDLoc(N), N->getValueType(0), Op0, + Op1, Merge, Mask, VL); +} + static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { assert(N->getOpcode() == ISD::SRA && "Unexpected opcode"); @@ -12229,6 +12261,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case RISCVISD::STRICT_VFMSUB_VL: case RISCVISD::STRICT_VFNMSUB_VL: return performVFMADD_VLCombine(N, DAG); + case RISCVISD::FMUL_VL: + return performVFMUL_VLCombine(N, DAG); case ISD::LOAD: case ISD::STORE: { if (DCI.isAfterLegalizeDAG()) @@ -15339,6 +15373,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(VWADDU_W_VL) NODE_NAME_CASE(VWSUB_W_VL) NODE_NAME_CASE(VWSUBU_W_VL) + NODE_NAME_CASE(VFWMUL_VL) NODE_NAME_CASE(VNSRL_VL) NODE_NAME_CASE(SETCC_VL) NODE_NAME_CASE(VSELECT_VL) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 829ff1f..af6849c 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -284,6 +284,8 @@ enum NodeType : unsigned { VWSUB_W_VL, VWSUBU_W_VL, + VFWMUL_VL, + // Narrowing logical shift right. // Operands are (source, shift, passthru, mask, vl) VNSRL_VL, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index 76e2a2b..b83ae5f 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -388,6 +388,8 @@ def riscv_vwaddu_vl : SDNode<"RISCVISD::VWADDU_VL", SDT_RISCVVWBinOp_VL, [SDNPCo def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWBinOp_VL, []>; def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWBinOp_VL, []>; +def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>; + def SDT_RISCVVNBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisSameNumEltsAs<0, 1>, SDTCisOpSmallerThanOp<0, 1>, @@ -726,6 +728,7 @@ multiclass VPatBinaryWVL_VV_VX { } } } + multiclass VPatBinaryWVL_VV_VX_WV_WX : VPatBinaryWVL_VV_VX { @@ -1346,6 +1349,24 @@ multiclass VPatWidenReductionVL_Ext_VL { + foreach fvtiToFWti = AllWidenableFloatVectors in { + defvar vti = fvtiToFWti.Vti; + defvar wti = fvtiToFWti.Wti; + let Predicates = !listconcat(GetVTypePredicates.Predicates, + GetVTypePredicates.Predicates) in { + defm : VPatBinaryVL_V; + defm : VPatBinaryVL_VF; + } + } +} + multiclass VPatWidenBinaryFPVL_VV_VF { foreach fvtiToFWti = AllWidenableFloatVectors in { defvar fvti = fvtiToFWti.Vti; @@ -1918,7 +1939,7 @@ defm : VPatBinaryFPVL_VV_VF_E; defm : VPatBinaryFPVL_R_VF_E; // 13.5. Vector Widening Floating-Point Multiply Instructions -defm : VPatWidenBinaryFPVL_VV_VF; +defm : VPatBinaryFPWVL_VV_VF; // 13.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions. defm : VPatFPMulAddVL_VV_VF; -- 2.7.4