[SelectionDAG] Remove duplicate "is scaled" information from gather/scatter SDNodes.
authorPaul Walker <paul.walker@arm.com>
Fri, 8 Apr 2022 11:59:21 +0000 (12:59 +0100)
committerPaul Walker <paul.walker@arm.com>
Mon, 16 May 2022 19:47:52 +0000 (20:47 +0100)
During early gather/scatter enablement two different approaches
were taken to represent scaled indices:

* A Scale operand whereby byte_offsets = Index * Scale
* An IndexType whereby byte_offsets = Index * sizeof(MemVT.ElementType)

Having multiple representations is bad as shown by this patch which
fixes instances where the two are out of sync. The dedicated scale
operand is more flexible and pervasive so this patch removes the
UNSCALED values from IndexType. This means all indices are scaled
but the scale can be one, hence unscaled. SDNodes now use the scale
operand to answer the "isScaledIndex" question.

I toyed with the idea of keeping the UNSCALED enums and helper
functions but because they will have no uses and force SDNodes to
validate the set of supported values I figured it's best to remove
them. We can re-add them if there's a real need. For similar
reasons I've kept the IndexType enum when a bool could be used as I
think being explicitly looks better.

Depends On D123347

Differential Revision: https://reviews.llvm.org/D123381

llvm/include/llvm/CodeGen/ISDOpcodes.h
llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.cpp

index 6ceb309..eae6bb9 100644 (file)
@@ -1358,37 +1358,17 @@ static const int LAST_INDEXED_MODE = POST_DEC + 1;
 /// MemIndexType enum - This enum defines how to interpret MGATHER/SCATTER's
 /// index parameter when calculating addresses.
 ///
-/// SIGNED_SCALED     Addr = Base + ((signed)Index * sizeof(element))
-/// SIGNED_UNSCALED   Addr = Base + (signed)Index
-/// UNSIGNED_SCALED   Addr = Base + ((unsigned)Index * sizeof(element))
-/// UNSIGNED_UNSCALED Addr = Base + (unsigned)Index
-enum MemIndexType {
-  SIGNED_SCALED = 0,
-  SIGNED_UNSCALED,
-  UNSIGNED_SCALED,
-  UNSIGNED_UNSCALED
-};
+/// SIGNED_SCALED     Addr = Base + ((signed)Index * Scale)
+/// UNSIGNED_SCALED   Addr = Base + ((unsigned)Index * Scale)
+///
+/// NOTE: The value of Scale is typically only known to the node owning the
+/// IndexType, with a value of 1 the equivalent of being unscaled.
+enum MemIndexType { SIGNED_SCALED = 0, UNSIGNED_SCALED };
 
-static const int LAST_MEM_INDEX_TYPE = UNSIGNED_UNSCALED + 1;
-
-inline bool isIndexTypeScaled(MemIndexType IndexType) {
-  return IndexType == SIGNED_SCALED || IndexType == UNSIGNED_SCALED;
-}
+static const int LAST_MEM_INDEX_TYPE = UNSIGNED_SCALED + 1;
 
 inline bool isIndexTypeSigned(MemIndexType IndexType) {
-  return IndexType == SIGNED_SCALED || IndexType == SIGNED_UNSCALED;
-}
-
-inline MemIndexType getSignedIndexType(MemIndexType IndexType) {
-  return isIndexTypeScaled(IndexType) ? SIGNED_SCALED : SIGNED_UNSCALED;
-}
-
-inline MemIndexType getUnsignedIndexType(MemIndexType IndexType) {
-  return isIndexTypeScaled(IndexType) ? UNSIGNED_SCALED : UNSIGNED_UNSCALED;
-}
-
-inline MemIndexType getUnscaledIndexType(MemIndexType IndexType) {
-  return isIndexTypeSigned(IndexType) ? SIGNED_UNSCALED : UNSIGNED_UNSCALED;
+  return IndexType == SIGNED_SCALED;
 }
 
 //===--------------------------------------------------------------------===//
index 659cfa2..5974f13 100644 (file)
@@ -2702,7 +2702,9 @@ public:
   ISD::MemIndexType getIndexType() const {
     return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
   }
-  bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); }
+  bool isIndexScaled() const {
+    return !cast<ConstantSDNode>(getScale())->isOne();
+  }
   bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); }
 
   // In the both nodes address is Op1, mask is Op2:
@@ -2784,7 +2786,9 @@ public:
   ISD::MemIndexType getIndexType() const {
     return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
   }
-  bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); }
+  bool isIndexScaled() const {
+    return !cast<ConstantSDNode>(getScale())->isOne();
+  }
   bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); }
 
   // In the both nodes address is Op1, mask is Op2:
index 9bcecad..0c619a4 100644 (file)
@@ -4891,10 +4891,6 @@ public:
   // combiner can fold the new nodes.
   SDValue lowerCmpEqZeroToCtlzSrl(SDValue Op, SelectionDAG &DAG) const;
 
-  /// Give targets the chance to reduce the number of distinct addresing modes.
-  ISD::MemIndexType getCanonicalIndexType(ISD::MemIndexType IndexType,
-                                          EVT MemVT, SDValue Offsets) const;
-
 private:
   SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
                            const SDLoc &DL, DAGCombinerInfo &DCI) const;
index 9d7f4de..1bde36a 100644 (file)
@@ -10446,12 +10446,12 @@ bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType,
   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
     SDValue Op = Index.getOperand(0);
     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
-      IndexType = ISD::getUnsignedIndexType(IndexType);
+      IndexType = ISD::UNSIGNED_SCALED;
       Index = Op;
       return true;
     }
     if (ISD::isIndexTypeSigned(IndexType)) {
-      IndexType = ISD::getUnsignedIndexType(IndexType);
+      IndexType = ISD::UNSIGNED_SCALED;
       return true;
     }
   }
index d654688..1c237c2 100644 (file)
@@ -8605,7 +8605,6 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT MemVT, const SDLoc &dl,
     return SDValue(E, 0);
   }
 
-  IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
   auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                           VTs, MemVT, MMO, IndexType, ExtTy);
   createOperands(N, Ops);
@@ -8653,7 +8652,6 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT MemVT, const SDLoc &dl,
     return SDValue(E, 0);
   }
 
-  IndexType = TLI->getCanonicalIndexType(IndexType, MemVT, Ops[4]);
   auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
                                            VTs, MemVT, MMO, IndexType, IsTrunc);
   createOperands(N, Ops);
index 170fa25..2e0d050 100644 (file)
@@ -4444,7 +4444,7 @@ void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
   if (!UniformBase) {
     Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
-    IndexType = ISD::SIGNED_UNSCALED;
+    IndexType = ISD::SIGNED_SCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
 
@@ -4552,7 +4552,7 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
   if (!UniformBase) {
     Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
-    IndexType = ISD::SIGNED_UNSCALED;
+    IndexType = ISD::SIGNED_SCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
 
@@ -7416,7 +7416,7 @@ void SelectionDAGBuilder::visitVPLoadGather(const VPIntrinsic &VPIntrin, EVT VT,
     if (!UniformBase) {
       Base = DAG.getConstant(0, DL, TLI.getPointerTy(DAG.getDataLayout()));
       Index = getValue(PtrOperand);
-      IndexType = ISD::SIGNED_UNSCALED;
+      IndexType = ISD::SIGNED_SCALED;
       Scale =
           DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
     }
@@ -7473,7 +7473,7 @@ void SelectionDAGBuilder::visitVPStoreScatter(const VPIntrinsic &VPIntrin,
     if (!UniformBase) {
       Base = DAG.getConstant(0, DL, TLI.getPointerTy(DAG.getDataLayout()));
       Index = getValue(PtrOperand);
-      IndexType = ISD::SIGNED_UNSCALED;
+      IndexType = ISD::SIGNED_SCALED;
       Scale =
           DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
     }
index 72d2216..f3a0936 100644 (file)
@@ -8631,18 +8631,6 @@ SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op,
   return SDValue();
 }
 
-// Convert redundant addressing modes (e.g. scaling is redundant
-// when accessing bytes).
-ISD::MemIndexType
-TargetLowering::getCanonicalIndexType(ISD::MemIndexType IndexType, EVT MemVT,
-                                      SDValue Offsets) const {
-  // Scaling is unimportant for bytes, canonicalize to unscaled.
-  if (ISD::isIndexTypeScaled(IndexType) && MemVT.getScalarType() == MVT::i8)
-    return ISD::getUnscaledIndexType(IndexType);
-
-  return IndexType;
-}
-
 SDValue TargetLowering::expandIntMINMAX(SDNode *Node, SelectionDAG &DAG) const {
   SDValue Op0 = Node->getOperand(0);
   SDValue Op1 = Node->getOperand(1);
index ae920d3..95d7269 100644 (file)
@@ -4695,7 +4695,6 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
     Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
 
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
-    IndexType = getUnscaledIndexType(IndexType);
     return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
                                MGT->getMemOperand(), IndexType, ExtType);
   }
@@ -4794,7 +4793,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
     Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
 
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
-    IndexType = getUnscaledIndexType(IndexType);
     return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
                                 MSC->getMemOperand(), IndexType,
                                 MSC->isTruncatingStore());
index f64e160..72af5c4 100644 (file)
@@ -8851,40 +8851,41 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                           DL, IndexVT, Index);
     }
 
-    unsigned Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue();
-    if (IsIndexScaled && Scale != 1) {
-      // Manually scale the indices by the element size.
+    if (IsIndexScaled) {
+      // Manually scale the indices.
       // TODO: Sanitize the scale operand here?
       // TODO: For VP nodes, should we use VP_SHL here?
+      unsigned Scale = cast<ConstantSDNode>(ScaleOp)->getZExtValue();
       assert(isPowerOf2_32(Scale) && "Expecting power-of-two types");
       SDValue SplatScale = DAG.getConstant(Log2_32(Scale), DL, IndexVT);
       Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index, SplatScale);
+      ScaleOp = DAG.getTargetConstant(1, DL, ScaleOp.getValueType());
     }
 
-    ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_UNSCALED;
+    ISD::MemIndexType NewIndexTy = ISD::UNSIGNED_SCALED;
     if (const auto *VPGN = dyn_cast<VPGatherSDNode>(N))
       return DAG.getGatherVP(N->getVTList(), VPGN->getMemoryVT(), DL,
                              {VPGN->getChain(), VPGN->getBasePtr(), Index,
-                              VPGN->getScale(), VPGN->getMask(),
+                              ScaleOp, VPGN->getMask(),
                               VPGN->getVectorLength()},
                              VPGN->getMemOperand(), NewIndexTy);
     if (const auto *VPSN = dyn_cast<VPScatterSDNode>(N))
       return DAG.getScatterVP(N->getVTList(), VPSN->getMemoryVT(), DL,
                               {VPSN->getChain(), VPSN->getValue(),
-                               VPSN->getBasePtr(), Index, VPSN->getScale(),
+                               VPSN->getBasePtr(), Index, ScaleOp,
                                VPSN->getMask(), VPSN->getVectorLength()},
                               VPSN->getMemOperand(), NewIndexTy);
     if (const auto *MGN = dyn_cast<MaskedGatherSDNode>(N))
       return DAG.getMaskedGather(
           N->getVTList(), MGN->getMemoryVT(), DL,
           {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
-           MGN->getBasePtr(), Index, MGN->getScale()},
+           MGN->getBasePtr(), Index, ScaleOp},
           MGN->getMemOperand(), NewIndexTy, MGN->getExtensionType());
     const auto *MSN = cast<MaskedScatterSDNode>(N);
     return DAG.getMaskedScatter(
         N->getVTList(), MSN->getMemoryVT(), DL,
         {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
-         Index, MSN->getScale()},
+         Index, ScaleOp},
         MSN->getMemOperand(), NewIndexTy, MSN->isTruncatingStore());
   }
   case RISCVISD::SRA_VL: