[RISCV] Introduce RISCVISD::VWMACC(U/SU)_VL opcode
authorNitin John Raj <nitin.raj@sifive.com>
Thu, 15 Jun 2023 02:10:44 +0000 (19:10 -0700)
committerNitin John Raj <nitin.raj@sifive.com>
Fri, 16 Jun 2023 23:11:35 +0000 (16:11 -0700)
Differential Revision: https://reviews.llvm.org/D153057

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll [moved from llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll with 100% similarity]

index 02f3b58..d7b841c 100644 (file)
@@ -12136,6 +12136,63 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   return convertFromScalableVector(VT, Res, DAG, Subtarget);
 }
 
+static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+  SDValue Addend = N->getOperand(0);
+  SDValue MulOp = N->getOperand(1);
+  SDValue AddMergeOp = N->getOperand(2);
+
+  if (!AddMergeOp.isUndef())
+    return SDValue();
+
+  auto IsVWMulOpc = [](unsigned Opc) {
+    switch (Opc) {
+    case RISCVISD::VWMUL_VL:
+    case RISCVISD::VWMULU_VL:
+    case RISCVISD::VWMULSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    std::swap(Addend, MulOp);
+
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    return SDValue();
+
+  SDValue MulMergeOp = MulOp.getOperand(2);
+
+  if (!MulMergeOp.isUndef())
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+  SDValue MulMask = MulOp.getOperand(3);
+  SDValue MulVL = MulOp.getOperand(4);
+
+  if (AddMask != MulMask || AddVL != MulVL)
+    return SDValue();
+
+  unsigned Opc = RISCVISD::VWMACC_VL + MulOp.getOpcode() - RISCVISD::VWMUL_VL;
+  static_assert(RISCVISD::VWMACC_VL + 1 == RISCVISD::VWMACCU_VL,
+                "Unexpected opcode after VWMACC_VL");
+  static_assert(RISCVISD::VWMACC_VL + 2 == RISCVISD::VWMACCSU_VL,
+                "Unexpected opcode after VWMACC_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 1 == RISCVISD::VWMULU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 2 == RISCVISD::VWMULSU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Ops[] = {MulOp.getOperand(0), MulOp.getOperand(1), Addend, AddMask,
+                   AddVL};
+  return DAG.getNode(Opc, DL, VT, Ops);
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -12546,6 +12603,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::ADD_VL:
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+      return V;
+    return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::SUB_VL:
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
@@ -15683,6 +15743,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VFWSUB_VL)
   NODE_NAME_CASE(VFWADD_W_VL)
   NODE_NAME_CASE(VFWSUB_W_VL)
+  NODE_NAME_CASE(VWMACC_VL)
+  NODE_NAME_CASE(VWMACCU_VL)
+  NODE_NAME_CASE(VWMACCSU_VL)
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)
index fb7b029..dddfe87 100644 (file)
@@ -294,6 +294,12 @@ enum NodeType : unsigned {
   VFWADD_W_VL,
   VFWSUB_W_VL,
 
+  // Widening ternary operations with a mask as the fourth operand and VL as the
+  // fifth operand.
+  VWMACC_VL,
+  VWMACCU_VL,
+  VWMACCSU_VL,
+
   // Narrowing logical shift right.
   // Operands are (source, shift, passthru, mask, vl)
   VNSRL_VL,
index abf1290..e17844c 100644 (file)
@@ -395,6 +395,19 @@ def riscv_vwaddu_vl  : SDNode<"RISCVISD::VWADDU_VL",  SDT_RISCVVWIntBinOp_VL, [S
 def riscv_vwsub_vl   : SDNode<"RISCVISD::VWSUB_VL",   SDT_RISCVVWIntBinOp_VL, []>;
 def riscv_vwsubu_vl  : SDNode<"RISCVISD::VWSUBU_VL",  SDT_RISCVVWIntBinOp_VL, []>;
 
+def SDT_RISCVVWIntTernOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
+                                                   SDTCisInt<1>,
+                                                   SDTCisSameNumEltsAs<0, 1>,
+                                                   SDTCisOpSmallerThanOp<1, 0>,
+                                                   SDTCisSameAs<1, 2>,
+                                                   SDTCisSameAs<0, 3>,
+                                                   SDTCisSameNumEltsAs<1, 4>,
+                                                   SDTCVecEltisVT<4, i1>,
+                                                   SDTCisVT<5, XLenVT>]>;
+def riscv_vwmacc_vl : SDNode<"RISCVISD::VWMACC_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccu_vl : SDNode<"RISCVISD::VWMACCU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccsu_vl : SDNode<"RISCVISD::VWMACCSU_VL", SDT_RISCVVWIntTernOp_VL, []>;
+
 def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
                                                  SDTCisFP<1>,
                                                  SDTCisSameNumEltsAs<0, 1>,
@@ -1407,30 +1420,27 @@ multiclass VPatMultiplyAccVL_VV_VX<PatFrag op, string instruction_name> {
   }
 }
 
-multiclass VPatWidenMultiplyAddVL_VV_VX<PatFrag op1, string instruction_name> {
+multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
   foreach vtiTowti = AllWidenableIntVectors in {
     defvar vti = vtiTowti.Vti;
     defvar wti = vtiTowti.Wti;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                  GetVTypePredicates<wti>.Predicates) in {
-      def : Pat<(wti.Vector
-               (riscv_add_vl wti.RegClass:$rd,
-                             (op1 vti.RegClass:$rs1,
-                                  (vti.Vector vti.RegClass:$rs2),
-                                  srcvalue, (vti.Mask true_mask), VLOpFrag),
-                            srcvalue, (vti.Mask true_mask), VLOpFrag)),
-              (!cast<Instruction>(instruction_name#"_VV_" # vti.LMul.MX)
-                   wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-      def : Pat<(wti.Vector
-               (riscv_add_vl wti.RegClass:$rd,
-                            (op1 (SplatPat XLenVT:$rs1),
-                                 (vti.Vector vti.RegClass:$rs2),
-                                 srcvalue, (vti.Mask true_mask), VLOpFrag),
-                             srcvalue, (vti.Mask true_mask), VLOpFrag)),
-              (!cast<Instruction>(instruction_name#"_VX_" # vti.LMul.MX)
-                   wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+      def : Pat<(vwmacc_op (vti.Vector vti.RegClass:$rs1),
+                           (vti.Vector vti.RegClass:$rs2),
+                           (wti.Vector wti.RegClass:$rd),
+                           (vti.Mask V0), VLOpFrag),
+                (!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
+                    wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+      def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1),
+                           (vti.Vector vti.RegClass:$rs2),
+                           (wti.Vector wti.RegClass:$rd),
+                           (vti.Mask V0), VLOpFrag),
+                (!cast<Instruction>(instr_name#"_VX_"#vti.LMul.MX#"_MASK")
+                    wti.RegClass:$rd, vti.ScalarRegClass:$rs1,
+                    vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW,
+                    TAIL_AGNOSTIC)>;
     }
   }
 }
@@ -1704,25 +1714,21 @@ defm : VPatMultiplyAccVL_VV_VX<riscv_add_vl_oneuse, "PseudoVMACC">;
 defm : VPatMultiplyAccVL_VV_VX<riscv_sub_vl_oneuse, "PseudoVNMSAC">;
 
 // 11.14. Vector Widening Integer Multiply-Add Instructions
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmul_vl_oneuse, "PseudoVWMACC">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulu_vl_oneuse, "PseudoVWMACCU">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulsu_vl_oneuse, "PseudoVWMACCSU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmacc_vl, "PseudoVWMACC">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccu_vl, "PseudoVWMACCU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccsu_vl, "PseudoVWMACCSU">;
 foreach vtiTowti = AllWidenableIntVectors in {
   defvar vti = vtiTowti.Vti;
   defvar wti = vtiTowti.Wti;
   let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
                                GetVTypePredicates<wti>.Predicates) in
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),
-                                                    (SplatPat XLenVT:$rs2),
-                                                    srcvalue,
-                                                    (vti.Mask true_mask),
-                                                    VLOpFrag),
-                           srcvalue, (vti.Mask true_mask),VLOpFrag)),
-            (!cast<Instruction>("PseudoVWMACCUS_VX_" # vti.LMul.MX)
-                 wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  def : Pat<(riscv_vwmaccsu_vl (vti.Vector vti.RegClass:$rs1),
+                               (SplatPat XLenVT:$rs2),
+                               (wti.Vector wti.RegClass:$rd),
+                               (vti.Mask V0), VLOpFrag),
+            (!cast<Instruction>("PseudoVWMACCUS_VX_"#vti.LMul.MX#"_MASK")
+                wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
+                (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 }
 
 // 11.15. Vector Integer Merge Instructions