[RISCV] Share reduction lowering code for vp.reduce
authorPhilip Reames <preames@rivosinc.com>
Fri, 9 Dec 2022 19:56:20 +0000 (11:56 -0800)
committerPhilip Reames <listmail@philipreames.com>
Fri, 9 Dec 2022 20:22:59 +0000 (12:22 -0800)
We can consolidate code and clarify edge case behavior at the same time.

There are two functional differences here.

First, I remove the ResVT handling, and always use the reduction element type. This appears to be dead code. There's no test coverage, and this code doesn't need to account for scalar type legalization anyways.

Second, if the VL happens to be known non-zero, we can avoid passing through start. This is mostly needed to allow reuse of the existing code; I don't consider it interesting as an optimization on it's own.

Differential Revision: https://reviews.llvm.org/D139733

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

index 11461fe..69e1c8c 100644 (file)
@@ -5796,6 +5796,13 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
   return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
 }
 
+static bool hasNonZeroAVL(SDValue AVL) {
+  auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
+  auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
+  return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
+         (ImmAVL && ImmAVL->getZExtValue() >= 1);
+}
+
 /// Helper to lower a reduction sequence of the form:
 /// scalar = reduce_op vec, scalar_start
 static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
@@ -5808,7 +5815,8 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue
   SDValue InitialSplat =
       lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
                        M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+  SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
+  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
                                   InitialSplat, Mask, VL);
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
                      DAG.getConstant(0, DL, XLenVT));
@@ -5951,29 +5959,17 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
     return SDValue();
 
   MVT VecVT = VecEVT.getSimpleVT();
-  MVT VecEltVT = VecVT.getVectorElementType();
   unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());
 
-  MVT ContainerVT = VecVT;
   if (VecVT.isFixedLengthVector()) {
-    ContainerVT = getContainerForFixedLengthVector(VecVT);
+    auto ContainerVT = getContainerForFixedLengthVector(VecVT);
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
   SDValue VL = Op.getOperand(3);
   SDValue Mask = Op.getOperand(2);
-
-  MVT M1VT = getLMUL1VT(ContainerVT);
-  MVT XLenVT = Subtarget.getXLenVT();
-  MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;
-
-  SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
-                                        DAG.getConstant(1, DL, XLenVT), M1VT,
-                                        DL, DAG, Subtarget);
-  SDValue Reduction =
-      DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
-  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
-                             DAG.getConstant(0, DL, XLenVT));
+  SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
+                                   DL, DAG, Subtarget);
   if (!VecVT.isInteger())
     return Elt0;
   return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());