[RISCV] Consolidate a bit of common logic for forming reductions
authorPhilip Reames <preames@rivosinc.com>
Fri, 9 Dec 2022 16:17:44 +0000 (08:17 -0800)
committerPhilip Reames <listmail@philipreames.com>
Fri, 9 Dec 2022 16:18:51 +0000 (08:18 -0800)
There's several patches in flght which change this code, better to only have one copy.

The VP case is left seperate for the moment as the result value type differs.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

index 2f10e94..11461fe 100644 (file)
@@ -5796,6 +5796,25 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
   return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
 }
 
+/// Helper to lower a reduction sequence of the form:
+/// scalar = reduce_op vec, scalar_start
+static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
+                                 SDLoc DL, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+  const MVT VecVT = Vec.getSimpleValueType();
+  const MVT VecEltVT = VecVT.getVectorElementType();
+  const MVT M1VT = getLMUL1VT(VecVT);
+  const MVT XLenVT = Subtarget.getXLenVT();
+
+  SDValue InitialSplat =
+      lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
+                       M1VT, DL, DAG, Subtarget);
+  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+                                  InitialSplat, Mask, VL);
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
+                     DAG.getConstant(0, DL, XLenVT));
+}
+
+
 SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
                                             SelectionDAG &DAG) const {
   SDLoc DL(Op);
@@ -5828,20 +5847,12 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
-  MVT M1VT = getLMUL1VT(ContainerVT);
-  MVT XLenVT = Subtarget.getXLenVT();
-
   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
 
   SDValue NeutralElem =
       DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags());
-  SDValue IdentitySplat =
-      lowerScalarSplat(SDValue(), NeutralElem, DAG.getConstant(1, DL, XLenVT),
-                       M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
-                                  IdentitySplat, Mask, VL);
-  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
-                             DAG.getConstant(0, DL, XLenVT));
+  SDValue Elt0 = lowerReductionSeq(RVVOpcode, NeutralElem, Vec, Mask, VL,
+                                   DL, DAG, Subtarget);
   return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());
 }
 
@@ -5892,18 +5903,9 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
     VectorVal = convertToScalableVector(ContainerVT, VectorVal, DAG, Subtarget);
   }
 
-  MVT M1VT = getLMUL1VT(VectorVal.getSimpleValueType());
-  MVT XLenVT = Subtarget.getXLenVT();
-
   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
-
-  SDValue ScalarSplat =
-      lowerScalarSplat(SDValue(), ScalarVal, DAG.getConstant(1, DL, XLenVT),
-                       M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT),
-                                  VectorVal, ScalarSplat, Mask, VL);
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
-                     DAG.getConstant(0, DL, XLenVT));
+  return lowerReductionSeq(RVVOpcode, ScalarVal, VectorVal, Mask, VL, DL, DAG,
+                           Subtarget);
 }
 
 static unsigned getRVVVPReductionOp(unsigned ISDOpcode) {