From f60d3ec0c7fd324bccf2275c8f28c390b2b5f069 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 8 Nov 2021 12:07:26 +0000 Subject: [PATCH] [DAG] Add BuildVectorSDNode::getConstantRawBits helper We have several places where we need to extract the raw bits data from a BUILD_VECTOR node, so consolidate this to a single helper function that handles Undefs and Integer/FP constants, including implicit truncation. This should make it easier to extend D113202 to handle more constant folding of bitcasted constant data. Differential Revision: https://reviews.llvm.org/D113351 --- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 8 +++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 70 ++++++-------------------- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 67 ++++++++++++++++++++++++ llvm/lib/Target/X86/X86ISelLowering.cpp | 41 ++++----------- 4 files changed, 100 insertions(+), 86 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index cc00af9..c2c5dbc 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -2049,6 +2049,14 @@ public: int32_t getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements, uint32_t BitWidth) const; + /// Extract the raw bit data from a build vector of Undef, Constant or + /// ConstantFP node elements. Each raw bit element will be \p + /// DstEltSizeInBits wide, undef elements are treated as zero, and entirely + /// undefined elements are flagged in \p UndefElements. + bool getConstantRawBits(bool IsLittleEndian, unsigned DstEltSizeInBits, + SmallVectorImpl &RawBitElements, + BitVector &UndefElements) const; + bool isConstant() const; static bool classof(const SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index d01e3b3..9f40a02 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -13039,68 +13039,30 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT); } - SDLoc DL(BV); - // Okay, we know the src/dst types are both integers of differing types. - // Handling growing first. assert(SrcEltVT.isInteger() && DstEltVT.isInteger()); - if (SrcBitSize < DstBitSize) { - unsigned NumInputsPerOutput = DstBitSize/SrcBitSize; - SmallVector Ops; - for (unsigned i = 0, e = BV->getNumOperands(); i != e; - i += NumInputsPerOutput) { - bool isLE = DAG.getDataLayout().isLittleEndian(); - APInt NewBits = APInt(DstBitSize, 0); - bool EltIsUndef = true; - for (unsigned j = 0; j != NumInputsPerOutput; ++j) { - // Shift the previously computed bits over. - NewBits <<= SrcBitSize; - SDValue Op = BV->getOperand(i+ (isLE ? (NumInputsPerOutput-j-1) : j)); - if (Op.isUndef()) continue; - EltIsUndef = false; - - NewBits |= cast(Op)->getAPIntValue(). - zextOrTrunc(SrcBitSize).zext(DstBitSize); - } + // TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a + // BuildVectorSDNode? + auto *BVN = cast(BV); - if (EltIsUndef) - Ops.push_back(DAG.getUNDEF(DstEltVT)); - else - Ops.push_back(DAG.getConstant(NewBits, DL, DstEltVT)); - } - - EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size()); - return DAG.getBuildVector(VT, DL, Ops); - } + // Extract the constant raw bit data. + BitVector UndefElements; + SmallVector RawBits; + bool IsLE = DAG.getDataLayout().isLittleEndian(); + if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements)) + return SDValue(); - // Finally, this must be the case where we are shrinking elements: each input - // turns into multiple outputs. - unsigned NumOutputsPerInput = SrcBitSize/DstBitSize; - EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, - NumOutputsPerInput*BV->getNumOperands()); + SDLoc DL(BV); SmallVector Ops; - - for (const SDValue &Op : BV->op_values()) { - if (Op.isUndef()) { - Ops.append(NumOutputsPerInput, DAG.getUNDEF(DstEltVT)); - continue; - } - - APInt OpVal = cast(Op)-> - getAPIntValue().zextOrTrunc(SrcBitSize); - - for (unsigned j = 0; j != NumOutputsPerInput; ++j) { - APInt ThisVal = OpVal.trunc(DstBitSize); - Ops.push_back(DAG.getConstant(ThisVal, DL, DstEltVT)); - OpVal.lshrInPlace(DstBitSize); - } - - // For big endian targets, swap the order of the pieces of each element. - if (DAG.getDataLayout().isBigEndian()) - std::reverse(Ops.end()-NumOutputsPerInput, Ops.end()); + for (unsigned I = 0, E = RawBits.size(); I != E; ++I) { + if (UndefElements[I]) + Ops.push_back(DAG.getUNDEF(DstEltVT)); + else + Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT)); } + EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size()); return DAG.getBuildVector(VT, DL, Ops); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 9667389..6739f53 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -10916,6 +10916,73 @@ BuildVectorSDNode::getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements, return -1; } +bool BuildVectorSDNode::getConstantRawBits( + bool IsLittleEndian, unsigned DstEltSizeInBits, + SmallVectorImpl &RawBitElements, BitVector &UndefElements) const { + // Early-out if this contains anything but Undef/Constant/ConstantFP. + if (!isConstant()) + return false; + + unsigned NumSrcOps = getNumOperands(); + unsigned SrcEltSizeInBits = getValueType(0).getScalarSizeInBits(); + assert(((NumSrcOps * SrcEltSizeInBits) % DstEltSizeInBits) == 0 && + "Invalid bitcast scale"); + + unsigned NumDstOps = (NumSrcOps * SrcEltSizeInBits) / DstEltSizeInBits; + UndefElements.clear(); + UndefElements.resize(NumDstOps, false); + RawBitElements.assign(NumDstOps, APInt::getNullValue(DstEltSizeInBits)); + + // Concatenate src elements constant bits together into dst element. + if (SrcEltSizeInBits <= DstEltSizeInBits) { + unsigned Scale = DstEltSizeInBits / SrcEltSizeInBits; + for (unsigned I = 0; I != NumDstOps; ++I) { + UndefElements.set(I); + APInt &RawBits = RawBitElements[I]; + for (unsigned J = 0; J != Scale; ++J) { + unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1)); + SDValue Op = getOperand(Idx); + if (Op.isUndef()) + continue; + UndefElements.reset(I); + auto *CInt = dyn_cast(Op); + auto *CFP = dyn_cast(Op); + assert((CInt || CFP) && "Unknown constant"); + APInt EltBits = + CInt ? CInt->getAPIntValue().truncOrSelf(SrcEltSizeInBits) + : CFP->getValueAPF().bitcastToAPInt(); + assert(EltBits.getBitWidth() == SrcEltSizeInBits && + "Illegal constant bitwidths"); + RawBits.insertBits(EltBits, J * SrcEltSizeInBits); + } + } + return true; + } + + // Split src element constant bits into dst elements. + unsigned Scale = SrcEltSizeInBits / DstEltSizeInBits; + for (unsigned I = 0; I != NumSrcOps; ++I) { + SDValue Op = getOperand(I); + if (Op.isUndef()) { + UndefElements.set(I * Scale, (I + 1) * Scale); + continue; + } + auto *CInt = dyn_cast(Op); + auto *CFP = dyn_cast(Op); + assert((CInt || CFP) && "Unknown constant"); + APInt EltBits = + CInt ? CInt->getAPIntValue() : CFP->getValueAPF().bitcastToAPInt(); + + for (unsigned J = 0; J != Scale; ++J) { + unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1)); + APInt &RawBits = RawBitElements[Idx]; + RawBits = EltBits.extractBits(DstEltSizeInBits, J * DstEltSizeInBits); + } + } + + return true; +} + bool BuildVectorSDNode::isConstant() const { for (const SDValue &Op : op_values()) { unsigned Opc = Op.getOpcode(); diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index ba2a16c..cf76491 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -6879,40 +6879,17 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, } // Extract constant bits from build vector. - if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) { + if (auto *BV = dyn_cast(Op)) { + BitVector Undefs; + SmallVector SrcEltBits; unsigned SrcEltSizeInBits = VT.getScalarSizeInBits(); - unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits; - - APInt UndefSrcElts(NumSrcElts, 0); - SmallVector SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0)); - for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { - const SDValue &Src = Op.getOperand(i); - if (Src.isUndef()) { - UndefSrcElts.setBit(i); - continue; - } - auto *Cst = cast(Src); - SrcEltBits[i] = Cst->getAPIntValue().zextOrTrunc(SrcEltSizeInBits); + if (BV->getConstantRawBits(true, SrcEltSizeInBits, SrcEltBits, Undefs)) { + APInt UndefSrcElts = APInt::getNullValue(SrcEltBits.size()); + for (unsigned I = 0, E = SrcEltBits.size(); I != E; ++I) + if (Undefs[I]) + UndefSrcElts.setBit(I); + return CastBitData(UndefSrcElts, SrcEltBits); } - return CastBitData(UndefSrcElts, SrcEltBits); - } - if (ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) { - unsigned SrcEltSizeInBits = VT.getScalarSizeInBits(); - unsigned NumSrcElts = SizeInBits / SrcEltSizeInBits; - - APInt UndefSrcElts(NumSrcElts, 0); - SmallVector SrcEltBits(NumSrcElts, APInt(SrcEltSizeInBits, 0)); - for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) { - const SDValue &Src = Op.getOperand(i); - if (Src.isUndef()) { - UndefSrcElts.setBit(i); - continue; - } - auto *Cst = cast(Src); - APInt RawBits = Cst->getValueAPF().bitcastToAPInt(); - SrcEltBits[i] = RawBits.zextOrTrunc(SrcEltSizeInBits); - } - return CastBitData(UndefSrcElts, SrcEltBits); } // Extract constant bits from constant pool vector. -- 2.7.4