[X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like...
authorCraig Topper <craig.topper@intel.com>
Sat, 28 Sep 2019 01:08:46 +0000 (01:08 +0000)
committerCraig Topper <craig.topper@intel.com>
Sat, 28 Sep 2019 01:08:46 +0000 (01:08 +0000)
Creating new nodes is what we usually do. Have to explicitly
check that we don't update to an existing node and having
to manually manage the worklist is unusual.

We can probably add a helper function to reduce the duplication
of having to check if we should create a gather or scatter, but
I wanted to just get the simple thing done.

llvm-svn: 373137

llvm/lib/Target/X86/X86ISelLowering.cpp

index 5741b80..521fc3c 100644 (file)
@@ -43381,26 +43381,36 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
+  auto *GorS = cast<MaskedGatherScatterSDNode>(N);
+  SDValue Chain = GorS->getChain();
+  SDValue Index = GorS->getIndex();
+  SDValue Mask = GorS->getMask();
+  SDValue Base = GorS->getBasePtr();
+  SDValue Scale = GorS->getScale();
 
   if (DCI.isBeforeLegalizeOps()) {
-    SDValue Index = N->getOperand(4);
     // Remove any sign extends from 32 or smaller to larger than 32.
     // Only do this before LegalizeOps in case we need the sign extend for
     // legalization.
-    if (Index.getOpcode() == ISD::SIGN_EXTEND) {
-      if (Index.getScalarValueSizeInBits() > 32 &&
-          Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
-        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-        NewOps[4] = Index.getOperand(0);
-        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-        if (Res == N) {
-          // The original sign extend has less users, add back to worklist in
-          // case it needs to be removed
-          DCI.AddToWorklist(Index.getNode());
-          DCI.AddToWorklist(N);
-        }
-        return SDValue(Res, 0);
-      }
+    if (Index.getOpcode() == ISD::SIGN_EXTEND &&
+        Index.getScalarValueSizeInBits() > 32 &&
+        Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
+      Index = Index.getOperand(0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
 
     // Make sure the index is either i32 or i64
@@ -43410,36 +43420,49 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
                                    Index.getValueType().getVectorNumElements());
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
-      SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-      NewOps[4] = Index;
-      SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-      if (Res == N)
-        DCI.AddToWorklist(N);
-      return SDValue(Res, 0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
 
     // Try to remove zero extends from 32->64 if we know the sign bit of
     // the input is zero.
     if (Index.getOpcode() == ISD::ZERO_EXTEND &&
         Index.getScalarValueSizeInBits() == 64 &&
-        Index.getOperand(0).getScalarValueSizeInBits() == 32) {
-      if (DAG.SignBitIsZero(Index.getOperand(0))) {
-        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-        NewOps[4] = Index.getOperand(0);
-        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-        if (Res == N) {
-          // The original sign extend has less users, add back to worklist in
-          // case it needs to be removed
-          DCI.AddToWorklist(Index.getNode());
-          DCI.AddToWorklist(N);
-        }
-        return SDValue(Res, 0);
-      }
+        Index.getOperand(0).getScalarValueSizeInBits() == 32 &&
+        DAG.SignBitIsZero(Index.getOperand(0))) {
+      Index = Index.getOperand(0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
   }
 
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));