[SelectionDAG] Properly copy ExtraInfo on RAUW
authorMarco Elver <elver@google.com>
Tue, 6 Sep 2022 13:48:58 +0000 (15:48 +0200)
committerMarco Elver <elver@google.com>
Tue, 6 Sep 2022 14:32:50 +0000 (16:32 +0200)
During SelectionDAG legalization SDNodes with associated extra info may
be replaced with a new SDNode. Preserve associated extra info on
ReplaceAllUsesWith and remove entries in DeallocateNode.

Reviewed By: vitalybuka

Differential Revision: https://reviews.llvm.org/D130881

llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

index a481881..ba7fc62 100644 (file)
@@ -2195,6 +2195,9 @@ public:
     return I != SDEI.end() ? I->second.NoMerge : false;
   }
 
+  /// Copy extra info associated with one node to another.
+  void copyExtraInfo(SDNode *From, SDNode *To);
+
   /// Return the current function's default denormal handling kind for the given
   /// floating point type.
   DenormalMode getDenormalMode(EVT VT) const {
index 3cb10c8..97b14b4 100644 (file)
@@ -1047,6 +1047,9 @@ void SelectionDAG::DeallocateNode(SDNode *N) {
   // If any of the SDDbgValue nodes refer to this SDNode, invalidate
   // them and forget about that node.
   DbgInfo->erase(N);
+
+  // Invalidate extra info.
+  SDEI.erase(N);
 }
 
 #ifndef NDEBUG
@@ -10177,6 +10180,8 @@ void SelectionDAG::ReplaceAllUsesWith(SDValue FromN, SDValue To) {
 
   // Preserve Debug Values
   transferDbgValues(FromN, To);
+  // Preserve extra info.
+  copyExtraInfo(From, To.getNode());
 
   // Iterate over all the existing uses of From. New uses will be added
   // to the beginning of the use list, which we avoid visiting.
@@ -10238,6 +10243,8 @@ void SelectionDAG::ReplaceAllUsesWith(SDNode *From, SDNode *To) {
       assert((i < To->getNumValues()) && "Invalid To location");
       transferDbgValues(SDValue(From, i), SDValue(To, i));
     }
+  // Preserve extra info.
+  copyExtraInfo(From, To);
 
   // Iterate over just the existing users of From. See the comments in
   // the ReplaceAllUsesWith above.
@@ -10280,9 +10287,12 @@ void SelectionDAG::ReplaceAllUsesWith(SDNode *From, const SDValue *To) {
   if (From->getNumValues() == 1)  // Handle the simple case efficiently.
     return ReplaceAllUsesWith(SDValue(From, 0), To[0]);
 
-  // Preserve Debug Info.
-  for (unsigned i = 0, e = From->getNumValues(); i != e; ++i)
+  for (unsigned i = 0, e = From->getNumValues(); i != e; ++i) {
+    // Preserve Debug Info.
     transferDbgValues(SDValue(From, i), To[i]);
+    // Preserve extra info.
+    copyExtraInfo(From, To[i].getNode());
+  }
 
   // Iterate over just the existing users of From. See the comments in
   // the ReplaceAllUsesWith above.
@@ -10335,6 +10345,7 @@ void SelectionDAG::ReplaceAllUsesOfValueWith(SDValue From, SDValue To){
 
   // Preserve Debug Info.
   transferDbgValues(From, To);
+  copyExtraInfo(From.getNode(), To.getNode());
 
   // Iterate over just the existing users of From. See the comments in
   // the ReplaceAllUsesWith above.
@@ -10488,6 +10499,7 @@ void SelectionDAG::ReplaceAllUsesOfValuesWith(const SDValue *From,
     return ReplaceAllUsesOfValueWith(*From, *To);
 
   transferDbgValues(*From, *To);
+  copyExtraInfo(From->getNode(), To->getNode());
 
   // Read up all the uses and make records of them. This helps
   // processing new uses that are introduced during the
@@ -11933,6 +11945,14 @@ SDValue SelectionDAG::getNeutralElement(unsigned Opcode, const SDLoc &DL,
   }
 }
 
+void SelectionDAG::copyExtraInfo(SDNode *From, SDNode *To) {
+  assert(From && To && "Invalid SDNode; empty source SDValue?");
+  auto I = SDEI.find(From);
+  if (I == SDEI.end())
+    return;
+  SDEI[To] = I->second;
+}
+
 #ifndef NDEBUG
 static void checkForCyclesHelper(const SDNode *N,
                                  SmallPtrSetImpl<const SDNode*> &Visited,
index f530d15..1fedcf1 100644 (file)
@@ -591,4 +591,33 @@ TEST_F(AArch64SelectionDAGTest, TestFold_STEP_VECTOR) {
   EXPECT_EQ(Op.getOpcode(), ISD::SPLAT_VECTOR);
 }
 
+TEST_F(AArch64SelectionDAGTest, ReplaceAllUsesWith) {
+  SDLoc Loc;
+  EVT IntVT = EVT::getIntegerVT(Context, 8);
+
+  SDValue N0 = DAG->getConstant(0x42, Loc, IntVT);
+  SDValue N1 = DAG->getRegister(0, IntVT);
+  // Construct node to fill arbitrary ExtraInfo.
+  SDValue N2 = DAG->getNode(ISD::SUB, Loc, IntVT, N0, N1);
+  EXPECT_FALSE(DAG->getHeapAllocSite(N2.getNode()));
+  EXPECT_FALSE(DAG->getNoMergeSiteInfo(N2.getNode()));
+  MDNode *MD = MDNode::get(Context, None);
+  DAG->addHeapAllocSite(N2.getNode(), MD);
+  DAG->addNoMergeSiteInfo(N2.getNode(), true);
+  EXPECT_EQ(DAG->getHeapAllocSite(N2.getNode()), MD);
+  EXPECT_TRUE(DAG->getNoMergeSiteInfo(N2.getNode()));
+
+  SDValue Root = DAG->getNode(ISD::ADD, Loc, IntVT, N2, N2);
+  EXPECT_EQ(Root->getOperand(0)->getOpcode(), ISD::SUB);
+  // Create new node and check that ExtraInfo is propagated on RAUW.
+  SDValue New = DAG->getNode(ISD::ADD, Loc, IntVT, N1, N1);
+  EXPECT_FALSE(DAG->getHeapAllocSite(New.getNode()));
+  EXPECT_FALSE(DAG->getNoMergeSiteInfo(New.getNode()));
+
+  DAG->ReplaceAllUsesWith(N2, New);
+  EXPECT_EQ(Root->getOperand(0), New);
+  EXPECT_EQ(DAG->getHeapAllocSite(New.getNode()), MD);
+  EXPECT_TRUE(DAG->getNoMergeSiteInfo(New.getNode()));
+}
+
 } // end namespace llvm