[DAGCombine] Refactor DAGCombiner::ReduceLoadWidth. NFCI
authorBjorn Pettersson <bjorn.a.pettersson@ericsson.com>
Tue, 11 Jan 2022 21:59:18 +0000 (22:59 +0100)
committerBjorn Pettersson <bjorn.a.pettersson@ericsson.com>
Sun, 16 Jan 2022 19:24:52 +0000 (20:24 +0100)
Update code comments in DAGCombiner::ReduceLoadWidth and refactor
the handling of SRL a bit. The refactoring is done with the intent
of adding support for folding away SRA by using SEXTLOAD in a
follow-up patch.

The function is also renamed as DAGCombiner::reduceLoadWidth.

Differential Revision: https://reviews.llvm.org/D117104

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

index 7e474db..7c684bd 100644 (file)
@@ -593,7 +593,7 @@ namespace {
     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue mergeTruncStores(StoreSDNode *N);
-    SDValue ReduceLoadWidth(SDNode *N);
+    SDValue reduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
     SDValue TransformFPLoadStorePair(SDNode *N);
@@ -5624,7 +5624,7 @@ bool DAGCombiner::BackwardsPropagateMask(SDNode *N) {
       if (And.getOpcode() == ISD ::AND)
         And = SDValue(
             DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
-      SDValue NewLoad = ReduceLoadWidth(And.getNode());
+      SDValue NewLoad = reduceLoadWidth(And.getNode());
       assert(NewLoad &&
              "Shouldn't be masking the load if it can't be narrowed");
       CombineTo(Load, NewLoad, NewLoad.getValue(1));
@@ -6024,7 +6024,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
                                 (N0.getOpcode() == ISD::ANY_EXTEND &&
                                  N0.getOperand(0).getOpcode() == ISD::LOAD))) {
-    if (SDValue Res = ReduceLoadWidth(N)) {
+    if (SDValue Res = reduceLoadWidth(N)) {
       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
         ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
       AddToWorklist(N);
@@ -9140,7 +9140,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
       return NewSRL;
 
   // Attempt to convert a srl of a load into a narrower zero-extending load.
-  if (SDValue NarrowLoad = ReduceLoadWidth(N))
+  if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
   // Here is a common situation. We want to optimize:
@@ -11357,7 +11357,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (sext (truncate (load x))) -> (sext (smaller load x))
     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -11621,7 +11621,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (zext (truncate (load x))) -> (zext (smaller load x))
     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -11864,7 +11864,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
   // fold (aext (truncate (load x))) -> (aext (smaller load x))
   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -12095,13 +12095,10 @@ SDValue DAGCombiner::visitAssertAlign(SDNode *N) {
   return SDValue();
 }
 
-/// If the result of a wider load is shifted to right of N  bits and then
-/// truncated to a narrower type and where N is a multiple of number of bits of
-/// the narrower type, transform it to a narrower load from address + N / num of
-/// bits of new type. Also narrow the load if the result is masked with an AND
-/// to effectively produce a smaller type. If the result is to be extended, also
-/// fold the extension to form a extending load.
-SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
+/// If the result of a load is shifted/masked/truncated to an effectively
+/// narrower type, try to transform the load to a narrower type and/or
+/// use an extending load.
+SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
   unsigned Opc = N->getOpcode();
 
   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
@@ -12113,7 +12110,14 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
   if (VT.isVector())
     return SDValue();
 
+  // The ShAmt variable is used to indicate that we've consumed a right
+  // shift. I.e. we want to narrow the width of the load by skipping to load the
+  // ShAmt least significant bits.
   unsigned ShAmt = 0;
+  // A special case is when the least significant bits from the load are masked
+  // away, but using an AND rather than a right shift. HasShiftedOffset is used
+  // to indicate that the narrowed load should be left-shifted ShAmt bits to get
+  // the result.
   bool HasShiftedOffset = false;
   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
   // extended to VT.
@@ -12122,23 +12126,29 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
   } else if (Opc == ISD::SRL) {
     // Another special-case: SRL is basically zero-extending a narrower value,
-    // or it maybe shifting a higher subword, half or byte into the lowest
+    // or it may be shifting a higher subword, half or byte into the lowest
     // bits.
-    ExtType = ISD::ZEXTLOAD;
-    N0 = SDValue(N, 0);
 
-    auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
-    auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-    if (!N01 || !LN0)
+    // Only handle shift with constant shift amount, and the shiftee must be a
+    // load.
+    auto *LN = dyn_cast<LoadSDNode>(N0);
+    auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!N1C || !LN)
+      return SDValue();
+    // If the shift amount is larger than the memory type then we're not
+    // accessing any of the loaded bytes.
+    ShAmt = N1C->getZExtValue();
+    uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
+    if (MemoryWidth <= ShAmt)
+      return SDValue();
+    // Attempt to fold away the SRL by using ZEXTLOAD.
+    ExtType = ISD::ZEXTLOAD;
+    ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
+    // If original load is a SEXTLOAD then we can't simply replace it by a
+    // ZEXTLOAD (we could potentially replace it by a more narrow SEXTLOAD
+    // followed by a ZEXT, but that is not handled at the moment).
+    if (LN->getExtensionType() == ISD::SEXTLOAD)
       return SDValue();
-
-    uint64_t ShiftAmt = N01->getZExtValue();
-    uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
-    if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
-      ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
-    else
-      ExtVT = EVT::getIntegerVT(*DAG.getContext(),
-                                VT.getScalarSizeInBits() - ShiftAmt);
   } else if (Opc == ISD::AND) {
     // An AND with a constant mask is the same as a truncate + zero-extend.
     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -12161,55 +12171,73 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
   }
 
-  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
-    SDValue SRL = N0;
-    if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
-      ShAmt = ConstShift->getZExtValue();
-      unsigned EVTBits = ExtVT.getScalarSizeInBits();
-      // Is the shift amount a multiple of size of VT?
-      if ((ShAmt & (EVTBits-1)) == 0) {
-        N0 = N0.getOperand(0);
-        // Is the load width a multiple of size of VT?
-        if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
-          return SDValue();
-      }
+  // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
+  // a right shift. Here we redo some of those checks, to possibly adjust the
+  // ExtVT even further based on "a masking AND". We could also end up here for
+  // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
+  // need to be done here as well.
+  if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
+    SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
+    // Bail out when the SRL has more than one use. This is done for historical
+    // (undocumented) reasons. Maybe intent was to guard the AND-masking below
+    // check below? And maybe it could be non-profitable to do the transform in
+    // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
+    // FIXME: Can't we just skip this check for the Opc==ISD::SRL case.
+    if (!SRL.hasOneUse())
+      return SDValue();
+
+    // Only handle shift with constant shift amount, and the shiftee must be a
+    // load.
+    auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
+    auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
+    if (!SRL1C || !LN)
+      return SDValue();
 
-      // At this point, we must have a load or else we can't do the transform.
-      auto *LN0 = dyn_cast<LoadSDNode>(N0);
-      if (!LN0) return SDValue();
+    // If the shift amount is larger than the input type then we're not
+    // accessing any of the loaded bytes.  If the load was a zextload/extload
+    // then the result of the shift+trunc is zero/undef (handled elsewhere).
+    ShAmt = SRL1C->getZExtValue();
+    if (ShAmt >= LN->getMemoryVT().getSizeInBits())
+      return SDValue();
 
-      // Because a SRL must be assumed to *need* to zero-extend the high bits
-      // (as opposed to anyext the high bits), we can't combine the zextload
-      // lowering of SRL and an sextload.
-      if (LN0->getExtensionType() == ISD::SEXTLOAD)
-        return SDValue();
+    // Because a SRL must be assumed to *need* to zero-extend the high bits
+    // (as opposed to anyext the high bits), we can't combine the zextload
+    // lowering of SRL and an sextload.
+    if (LN->getExtensionType() == ISD::SEXTLOAD)
+      return SDValue();
 
-      // If the shift amount is larger than the input type then we're not
-      // accessing any of the loaded bytes.  If the load was a zextload/extload
-      // then the result of the shift+trunc is zero/undef (handled elsewhere).
-      if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
-        return SDValue();
+    unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
+    // Is the shift amount a multiple of size of ExtVT?
+    if ((ShAmt & (ExtVTBits - 1)) != 0)
+      return SDValue();
+    // Is the load width a multiple of size of ExtVT?
+    if ((SRL.getScalarValueSizeInBits() & (ExtVTBits - 1)) != 0)
+      return SDValue();
 
-      // If the SRL is only used by a masking AND, we may be able to adjust
-      // the ExtVT to make the AND redundant.
-      SDNode *Mask = *(SRL->use_begin());
-      if (Mask->getOpcode() == ISD::AND &&
-          isa<ConstantSDNode>(Mask->getOperand(1))) {
-        const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
-        if (ShiftMask.isMask()) {
-          EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
-                                           ShiftMask.countTrailingOnes());
-          // If the mask is smaller, recompute the type.
-          if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
-              TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
-            ExtVT = MaskedVT;
-        }
+    // If the SRL is only used by a masking AND, we may be able to adjust
+    // the ExtVT to make the AND redundant.
+    SDNode *Mask = *(SRL->use_begin());
+    if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
+        isa<ConstantSDNode>(Mask->getOperand(1))) {
+      const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
+      if (ShiftMask.isMask()) {
+        EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
+                                         ShiftMask.countTrailingOnes());
+        // If the mask is smaller, recompute the type.
+        if ((ExtVTBits > MaskedVT.getScalarSizeInBits()) &&
+            TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
+          ExtVT = MaskedVT;
       }
     }
+
+    N0 = SRL.getOperand(0);
   }
 
-  // If the load is shifted left (and the result isn't shifted back right),
-  // we can fold the truncate through the shift.
+  // If the load is shifted left (and the result isn't shifted back right), we
+  // can fold a truncate through the shift. The typical scenario is that N
+  // points at a TRUNCATE here so the attempted fold is:
+  //   (truncate (shl (load x), c))) -> (shl (narrow load x), c)
+  // ShLeftAmt will indicate how much a narrowed load should be shifted left.
   unsigned ShLeftAmt = 0;
   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
@@ -12237,12 +12265,12 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
     return LVTStoreBits - EVTStoreBits - ShAmt;
   };
 
-  // For big endian targets, we need to adjust the offset to the pointer to
-  // load the correct bytes.
-  if (DAG.getDataLayout().isBigEndian())
-    ShAmt = AdjustBigEndianShift(ShAmt);
+  // We need to adjust the pointer to the load by ShAmt bits in order to load
+  // the correct bytes.
+  unsigned PtrAdjustmentInBits =
+      DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
 
-  uint64_t PtrOff = ShAmt / 8;
+  uint64_t PtrOff = PtrAdjustmentInBits / 8;
   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
   SDLoc DL(LN0);
   // The original load itself didn't wrap, so an offset within it doesn't.
@@ -12285,11 +12313,6 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
   }
 
   if (HasShiftedOffset) {
-    // Recalculate the shift amount after it has been altered to calculate
-    // the offset.
-    if (DAG.getDataLayout().isBigEndian())
-      ShAmt = AdjustBigEndianShift(ShAmt);
-
     // We're using a shifted mask, so the load now has an offset. This means
     // that data has been loaded into the lower bytes than it would have been
     // before, so we need to shl the loaded data into the correct position in the
@@ -12382,7 +12405,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
 
   // fold (sext_in_reg (load x)) -> (smaller sextload x)
   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
-  if (SDValue NarrowLoad = ReduceLoadWidth(N))
+  if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
@@ -12669,7 +12692,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   // fold (truncate (load x)) -> (smaller load x)
   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
-    if (SDValue Reduced = ReduceLoadWidth(N))
+    if (SDValue Reduced = reduceLoadWidth(N))
       return Reduced;
 
     // Handle the case where the load remains an extending load even