From a200b0fc256a890b3f72014d20fce9e49d75763b Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Mon, 3 Oct 2022 12:18:21 -0700 Subject: [PATCH] [DAG] Introduce getSplat utility for common dispatch pattern [nfc] We have a very common pattern of dispatching between BUILD_VECTOR and SPLAT_VECTOR creation repeated in many cases in code. Common the pattern into a utility function. --- llvm/include/llvm/CodeGen/SelectionDAG.h | 10 +++++++++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 18 +++++----------- .../lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp | 8 ++------ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 13 ++++-------- .../CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 24 +++++----------------- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 4 +--- 6 files changed, 27 insertions(+), 50 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 969199b..b6f71ab 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -862,6 +862,16 @@ public: return getNode(ISD::SPLAT_VECTOR, DL, VT, Op); } + /// Returns a node representing a splat of one value into all lanes + /// of the provided vector type. This is a utility which returns + /// either a BUILD_VECTOR or SPLAT_VECTOR depending on the + /// scalability of the desired vector type. + SDValue getSplat(EVT VT, const SDLoc &DL, SDValue Op) { + assert(VT.isVector() && "Can't splat to non-vector type"); + return VT.isScalableVector() ? + getSplatVector(VT, DL, Op) : getSplatBuildVector(VT, DL, Op); + } + /// Returns a vector of type ResVT whose elements contain the linear sequence /// <0, Step, Step * 2, Step * 3, ...> SDValue getStepVector(const SDLoc &DL, EVT ResVT, APInt StepVal); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index a6c9c46..9281b4e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -3469,11 +3469,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { if (VT.isVector()) { SDValue N1S = DAG.getSplatValue(N1, true); if (N1S && N1S.getOpcode() == ISD::SUB && - isNullConstant(N1S.getOperand(0))) { - if (VT.isScalableVector()) - return DAG.getSplatVector(VT, DL, N1S.getOperand(1)); - return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1)); - } + isNullConstant(N1S.getOperand(0))) + return DAG.getSplat(VT, DL, N1S.getOperand(1)); } } @@ -19778,11 +19775,8 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { if (!IndexC) { // If this is variable insert to undef vector, it might be better to splat: // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... > - if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) { - if (VT.isScalableVector()) - return DAG.getSplatVector(VT, DL, InVal); - return DAG.getSplatBuildVector(VT, DL, InVal); - } + if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) + return DAG.getSplat(VT, DL, InVal); return SDValue(); } @@ -23817,9 +23811,7 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG, } // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index - if (VT.isScalableVector()) - return DAG.getSplatVector(VT, DL, ScalarBO); - return DAG.getSplatBuildVector(VT, DL, ScalarBO); + return DAG.getSplat(VT, DL, ScalarBO); } /// Visit a binary vector operation, like ADD. diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 0132bf4..6f0cde6 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -963,10 +963,7 @@ SDValue VectorLegalizer::ExpandSELECT(SDNode *Node) { DAG.getConstant(0, DL, BitTy)); // Broadcast the mask so that the entire vector is all one or all zero. - if (VT.isFixedLengthVector()) - Mask = DAG.getSplatBuildVector(MaskTy, DL, Mask); - else - Mask = DAG.getSplatVector(MaskTy, DL, Mask); + Mask = DAG.getSplat(MaskTy, DL, Mask); // Bitcast the operands to be the same type as the mask. // This is needed when we select between FP types because @@ -1309,8 +1306,7 @@ SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) { return DAG.UnrollVectorOp(Node); SDValue StepVec = DAG.getStepVector(DL, EVLVecVT); - SDValue SplatEVL = IsFixedLen ? DAG.getSplatBuildVector(EVLVecVT, DL, EVL) - : DAG.getSplatVector(EVLVecVT, DL, EVL); + SDValue SplatEVL = DAG.getSplat(EVLVecVT, DL, EVL); SDValue EVLMask = DAG.getSetCC(DL, MaskVT, StepVec, SplatEVL, ISD::CondCode::SETULT); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 3c2a116..8070306 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1607,11 +1607,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL, } SDValue Result(N, 0); - if (VT.isScalableVector()) - Result = getSplatVector(VT, DL, Result); - else if (VT.isVector()) - Result = getSplatBuildVector(VT, DL, Result); - + if (VT.isVector()) + Result = getSplat(VT, DL, Result); return Result; } @@ -1663,10 +1660,8 @@ SDValue SelectionDAG::getConstantFP(const ConstantFP &V, const SDLoc &DL, } SDValue Result(N, 0); - if (VT.isScalableVector()) - Result = getSplatVector(VT, DL, Result); - else if (VT.isVector()) - Result = getSplatBuildVector(VT, DL, Result); + if (VT.isVector()) + Result = getSplat(VT, DL, Result); NewSDValueDbgMsg(Result, "Creating fp constant: ", this); return Result; } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 60e6303..3308134 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -1695,9 +1695,7 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) { else Op = DAG.getConstant(0, getCurSDLoc(), EltVT); - if (isa(VecTy)) - return NodeMap[V] = DAG.getSplatVector(VT, getCurSDLoc(), Op); - return NodeMap[V] = DAG.getSplatBuildVector(VT, getCurSDLoc(), Op); + return NodeMap[V] = DAG.getSplat(VT, getCurSDLoc(), Op); } llvm_unreachable("Unknown vector constant"); @@ -3904,10 +3902,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) { if (IsVectorGEP && !N.getValueType().isVector()) { LLVMContext &Context = *DAG.getContext(); EVT VT = EVT::getVectorVT(Context, N.getValueType(), VectorElementCount); - if (VectorElementCount.isScalable()) - N = DAG.getSplatVector(VT, dl, N); - else - N = DAG.getSplatBuildVector(VT, dl, N); + N = DAG.getSplat(VT, dl, N); } for (gep_type_iterator GTI = gep_type_begin(&I), E = gep_type_end(&I); @@ -3979,10 +3974,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) { if (!IdxN.getValueType().isVector() && IsVectorGEP) { EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(), VectorElementCount); - if (VectorElementCount.isScalable()) - IdxN = DAG.getSplatVector(VT, dl, IdxN); - else - IdxN = DAG.getSplatBuildVector(VT, dl, IdxN); + IdxN = DAG.getSplat(VT, dl, IdxN); } // If the index is smaller or larger than intptr_t, truncate or extend @@ -7247,14 +7239,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, SDValue TripCount = getValue(I.getOperand(1)); auto VecTy = CCVT.changeVectorElementType(ElementVT); - SDValue VectorIndex, VectorTripCount; - if (VecTy.isScalableVector()) { - VectorIndex = DAG.getSplatVector(VecTy, sdl, Index); - VectorTripCount = DAG.getSplatVector(VecTy, sdl, TripCount); - } else { - VectorIndex = DAG.getSplatBuildVector(VecTy, sdl, Index); - VectorTripCount = DAG.getSplatBuildVector(VecTy, sdl, TripCount); - } + SDValue VectorIndex = DAG.getSplat(VecTy, sdl, Index); + SDValue VectorTripCount = DAG.getSplat(VecTy, sdl, TripCount); SDValue VectorStep = DAG.getStepVector(sdl, VecTy); SDValue VectorInduction = DAG.getNode( ISD::UADDSAT, sdl, VecTy, VectorIndex, VectorStep); diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index af8fdb7..3b36521 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -4192,9 +4192,7 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const { // Lower vector SELECTs to VSELECTs by splatting the condition. if (VT.isVector()) { MVT SplatCondVT = VT.changeVectorElementType(MVT::i1); - SDValue CondSplat = VT.isScalableVector() - ? DAG.getSplatVector(SplatCondVT, DL, CondV) - : DAG.getSplatBuildVector(SplatCondVT, DL, CondV); + SDValue CondSplat = DAG.getSplat(SplatCondVT, DL, CondV); return DAG.getNode(ISD::VSELECT, DL, VT, CondSplat, TrueV, FalseV); } -- 2.7.4