[DAG] Split BuildVectorSDNode::getConstantRawBits into BuildVectorSDNode::recastRawBi...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 10 Nov 2021 13:06:07 +0000 (13:06 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Wed, 10 Nov 2021 13:06:19 +0000 (13:06 +0000)
NFC refactor of D113351, pulling out the APInt split/merge code from the BuildVectorSDNode bits extraction into a BuildVectorSDNode::recastRawBits helper. This is to allow us to reuse the code when we're packing constant folded APInt data back together.

llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

index c2c5dbc264785253420040a43caebd3c3588d354..115ea9cfedf1496a512f76d10fa401bd67e5a1be 100644 (file)
@@ -2059,6 +2059,15 @@ public:
 
   bool isConstant() const;
 
+  /// Recast bit data \p SrcBitElements to \p DstEltSizeInBits wide elements.
+  /// Undef elements are treated as zero, and entirely undefined elements are
+  /// flagged in \p DstUndefElements.
+  static void recastRawBits(bool IsLittleEndian, unsigned DstEltSizeInBits,
+                            SmallVectorImpl<APInt> &DstBitElements,
+                            ArrayRef<APInt> SrcBitElements,
+                            BitVector &DstUndefElements,
+                            const BitVector &SrcUndefElements);
+
   static bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::BUILD_VECTOR;
   }
index 0d5513d9e39ef496d846ef39701e4fb0dd00cc00..0fa9abf612974dafc69015e0707da7fb432e269e 100644 (file)
@@ -10830,59 +10830,83 @@ bool BuildVectorSDNode::getConstantRawBits(
   assert(((NumSrcOps * SrcEltSizeInBits) % DstEltSizeInBits) == 0 &&
          "Invalid bitcast scale");
 
+  // Extract raw src bits.
+  SmallVector<APInt> SrcBitElements(NumSrcOps,
+                                    APInt::getNullValue(SrcEltSizeInBits));
+  BitVector SrcUndeElements(NumSrcOps, false);
+
+  for (unsigned I = 0; I != NumSrcOps; ++I) {
+    SDValue Op = getOperand(I);
+    if (Op.isUndef()) {
+      SrcUndeElements.set(I);
+      continue;
+    }
+    auto *CInt = dyn_cast<ConstantSDNode>(Op);
+    auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
+    assert((CInt || CFP) && "Unknown constant");
+    SrcBitElements[I] =
+        CInt ? CInt->getAPIntValue().truncOrSelf(SrcEltSizeInBits)
+             : CFP->getValueAPF().bitcastToAPInt();
+  }
+
+  // Recast to dst width.
+  recastRawBits(IsLittleEndian, DstEltSizeInBits, RawBitElements,
+                SrcBitElements, UndefElements, SrcUndeElements);
+  return true;
+}
+
+void BuildVectorSDNode::recastRawBits(bool IsLittleEndian,
+                                      unsigned DstEltSizeInBits,
+                                      SmallVectorImpl<APInt> &DstBitElements,
+                                      ArrayRef<APInt> SrcBitElements,
+                                      BitVector &DstUndefElements,
+                                      const BitVector &SrcUndefElements) {
+  unsigned NumSrcOps = SrcBitElements.size();
+  unsigned SrcEltSizeInBits = SrcBitElements[0].getBitWidth();
+  assert(((NumSrcOps * SrcEltSizeInBits) % DstEltSizeInBits) == 0 &&
+         "Invalid bitcast scale");
+  assert(NumSrcOps == SrcUndefElements.size() &&
+         "Vector size mismatch");
+
   unsigned NumDstOps = (NumSrcOps * SrcEltSizeInBits) / DstEltSizeInBits;
-  UndefElements.clear();
-  UndefElements.resize(NumDstOps, false);
-  RawBitElements.assign(NumDstOps, APInt::getNullValue(DstEltSizeInBits));
+  DstUndefElements.clear();
+  DstUndefElements.resize(NumDstOps, false);
+  DstBitElements.assign(NumDstOps, APInt::getNullValue(DstEltSizeInBits));
 
   // Concatenate src elements constant bits together into dst element.
   if (SrcEltSizeInBits <= DstEltSizeInBits) {
     unsigned Scale = DstEltSizeInBits / SrcEltSizeInBits;
     for (unsigned I = 0; I != NumDstOps; ++I) {
-      UndefElements.set(I);
-      APInt &RawBits = RawBitElements[I];
+      DstUndefElements.set(I);
+      APInt &DstBits = DstBitElements[I];
       for (unsigned J = 0; J != Scale; ++J) {
         unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1));
-        SDValue Op = getOperand(Idx);
-        if (Op.isUndef())
+        if (SrcUndefElements[Idx])
           continue;
-        UndefElements.reset(I);
-        auto *CInt = dyn_cast<ConstantSDNode>(Op);
-        auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
-        assert((CInt || CFP) && "Unknown constant");
-        APInt EltBits =
-            CInt ? CInt->getAPIntValue().truncOrSelf(SrcEltSizeInBits)
-                 : CFP->getValueAPF().bitcastToAPInt();
-        assert(EltBits.getBitWidth() == SrcEltSizeInBits &&
+        DstUndefElements.reset(I);
+        const APInt &SrcBits = SrcBitElements[Idx];
+        assert(SrcBits.getBitWidth() == SrcEltSizeInBits &&
                "Illegal constant bitwidths");
-        RawBits.insertBits(EltBits, J * SrcEltSizeInBits);
+        DstBits.insertBits(SrcBits, J * SrcEltSizeInBits);
       }
     }
-    return true;
+    return;
   }
 
   // Split src element constant bits into dst elements.
   unsigned Scale = SrcEltSizeInBits / DstEltSizeInBits;
   for (unsigned I = 0; I != NumSrcOps; ++I) {
-    SDValue Op = getOperand(I);
-    if (Op.isUndef()) {
-      UndefElements.set(I * Scale, (I + 1) * Scale);
+    if (SrcUndefElements[I]) {
+      DstUndefElements.set(I * Scale, (I + 1) * Scale);
       continue;
     }
-    auto *CInt = dyn_cast<ConstantSDNode>(Op);
-    auto *CFP = dyn_cast<ConstantFPSDNode>(Op);
-    assert((CInt || CFP) && "Unknown constant");
-    APInt EltBits =
-        CInt ? CInt->getAPIntValue() : CFP->getValueAPF().bitcastToAPInt();
-
+    const APInt &SrcBits = SrcBitElements[I];
     for (unsigned J = 0; J != Scale; ++J) {
       unsigned Idx = (I * Scale) + (IsLittleEndian ? J : (Scale - J - 1));
-      APInt &RawBits = RawBitElements[Idx];
-      RawBits = EltBits.extractBits(DstEltSizeInBits, J * DstEltSizeInBits);
+      APInt &DstBits = DstBitElements[Idx];
+      DstBits = SrcBits.extractBits(DstEltSizeInBits, J * DstEltSizeInBits);
     }
   }
-
-  return true;
 }
 
 bool BuildVectorSDNode::isConstant() const {