return convertFromScalableVector(VT, Res, DAG, Subtarget);
}
+static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ assert(N->getOpcode() == RISCVISD::ADD_VL);
+ SDValue Addend = N->getOperand(0);
+ SDValue MulOp = N->getOperand(1);
+ SDValue AddMergeOp = N->getOperand(2);
+
+ if (!AddMergeOp.isUndef())
+ return SDValue();
+
+ auto IsVWMulOpc = [](unsigned Opc) {
+ switch (Opc) {
+ case RISCVISD::VWMUL_VL:
+ case RISCVISD::VWMULU_VL:
+ case RISCVISD::VWMULSU_VL:
+ return true;
+ default:
+ return false;
+ }
+ };
+
+ if (!IsVWMulOpc(MulOp.getOpcode()))
+ std::swap(Addend, MulOp);
+
+ if (!IsVWMulOpc(MulOp.getOpcode()))
+ return SDValue();
+
+ SDValue MulMergeOp = MulOp.getOperand(2);
+
+ if (!MulMergeOp.isUndef())
+ return SDValue();
+
+ SDValue AddMask = N->getOperand(3);
+ SDValue AddVL = N->getOperand(4);
+ SDValue MulMask = MulOp.getOperand(3);
+ SDValue MulVL = MulOp.getOperand(4);
+
+ if (AddMask != MulMask || AddVL != MulVL)
+ return SDValue();
+
+ unsigned Opc = RISCVISD::VWMACC_VL + MulOp.getOpcode() - RISCVISD::VWMUL_VL;
+ static_assert(RISCVISD::VWMACC_VL + 1 == RISCVISD::VWMACCU_VL,
+ "Unexpected opcode after VWMACC_VL");
+ static_assert(RISCVISD::VWMACC_VL + 2 == RISCVISD::VWMACCSU_VL,
+ "Unexpected opcode after VWMACC_VL!");
+ static_assert(RISCVISD::VWMUL_VL + 1 == RISCVISD::VWMULU_VL,
+ "Unexpected opcode after VWMUL_VL!");
+ static_assert(RISCVISD::VWMUL_VL + 2 == RISCVISD::VWMULSU_VL,
+ "Unexpected opcode after VWMUL_VL!");
+
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ SDValue Ops[] = {MulOp.getOperand(0), MulOp.getOperand(1), Addend, AddMask,
+ AddVL};
+ return DAG.getNode(Opc, DL, VT, Ops);
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
break;
}
case RISCVISD::ADD_VL:
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+ return V;
+ return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::SUB_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
NODE_NAME_CASE(VFWSUB_VL)
NODE_NAME_CASE(VFWADD_W_VL)
NODE_NAME_CASE(VFWSUB_W_VL)
+ NODE_NAME_CASE(VWMACC_VL)
+ NODE_NAME_CASE(VWMACCU_VL)
+ NODE_NAME_CASE(VWMACCSU_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWIntBinOp_VL, []>;
def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWIntBinOp_VL, []>;
+def SDT_RISCVVWIntTernOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
+ SDTCisInt<1>,
+ SDTCisSameNumEltsAs<0, 1>,
+ SDTCisOpSmallerThanOp<1, 0>,
+ SDTCisSameAs<1, 2>,
+ SDTCisSameAs<0, 3>,
+ SDTCisSameNumEltsAs<1, 4>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisVT<5, XLenVT>]>;
+def riscv_vwmacc_vl : SDNode<"RISCVISD::VWMACC_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccu_vl : SDNode<"RISCVISD::VWMACCU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccsu_vl : SDNode<"RISCVISD::VWMACCSU_VL", SDT_RISCVVWIntTernOp_VL, []>;
+
def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
SDTCisFP<1>,
SDTCisSameNumEltsAs<0, 1>,
}
}
-multiclass VPatWidenMultiplyAddVL_VV_VX<PatFrag op1, string instruction_name> {
+multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
- def : Pat<(wti.Vector
- (riscv_add_vl wti.RegClass:$rd,
- (op1 vti.RegClass:$rs1,
- (vti.Vector vti.RegClass:$rs2),
- srcvalue, (vti.Mask true_mask), VLOpFrag),
- srcvalue, (vti.Mask true_mask), VLOpFrag)),
- (!cast<Instruction>(instruction_name#"_VV_" # vti.LMul.MX)
- wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
- GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
- def : Pat<(wti.Vector
- (riscv_add_vl wti.RegClass:$rd,
- (op1 (SplatPat XLenVT:$rs1),
- (vti.Vector vti.RegClass:$rs2),
- srcvalue, (vti.Mask true_mask), VLOpFrag),
- srcvalue, (vti.Mask true_mask), VLOpFrag)),
- (!cast<Instruction>(instruction_name#"_VX_" # vti.LMul.MX)
- wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
- GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(vwmacc_op (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector wti.RegClass:$rd),
+ (vti.Mask V0), VLOpFrag),
+ (!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
+ wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector wti.RegClass:$rd),
+ (vti.Mask V0), VLOpFrag),
+ (!cast<Instruction>(instr_name#"_VX_"#vti.LMul.MX#"_MASK")
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs1,
+ vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW,
+ TAIL_AGNOSTIC)>;
}
}
}
defm : VPatMultiplyAccVL_VV_VX<riscv_sub_vl_oneuse, "PseudoVNMSAC">;
// 11.14. Vector Widening Integer Multiply-Add Instructions
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmul_vl_oneuse, "PseudoVWMACC">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulu_vl_oneuse, "PseudoVWMACCU">;
-defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmulsu_vl_oneuse, "PseudoVWMACCSU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmacc_vl, "PseudoVWMACC">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccu_vl, "PseudoVWMACCU">;
+defm : VPatWidenMultiplyAddVL_VV_VX<riscv_vwmaccsu_vl, "PseudoVWMACCSU">;
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in
- def : Pat<(wti.Vector
- (riscv_add_vl wti.RegClass:$rd,
- (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),
- (SplatPat XLenVT:$rs2),
- srcvalue,
- (vti.Mask true_mask),
- VLOpFrag),
- srcvalue, (vti.Mask true_mask),VLOpFrag)),
- (!cast<Instruction>("PseudoVWMACCUS_VX_" # vti.LMul.MX)
- wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
- GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(riscv_vwmaccsu_vl (vti.Vector vti.RegClass:$rs1),
+ (SplatPat XLenVT:$rs2),
+ (wti.Vector wti.RegClass:$rd),
+ (vti.Mask V0), VLOpFrag),
+ (!cast<Instruction>("PseudoVWMACCUS_VX_"#vti.LMul.MX#"_MASK")
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
+ (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
}
// 11.15. Vector Integer Merge Instructions