[ORC] Hold ResourceTracker in MaterializationResponsibility.
authorLang Hames <lhames@gmail.com>
Wed, 1 Dec 2021 23:48:50 +0000 (10:48 +1100)
committerLang Hames <lhames@gmail.com>
Thu, 2 Dec 2021 06:36:31 +0000 (17:36 +1100)
This keeps the tracker alive for the lifetime of the MR. This is needed so that
we can check whether the tracker has become defunct before posting results (or
failure) for the MR.

llvm/include/llvm/ExecutionEngine/Orc/Core.h
llvm/lib/ExecutionEngine/Orc/Core.cpp

index 362e8ab..d11a0c9 100644 (file)
@@ -519,6 +519,7 @@ private:
 /// symbols of an error.
 class MaterializationResponsibility {
   friend class ExecutionSession;
+  friend class JITDylib;
 
 public:
   MaterializationResponsibility(MaterializationResponsibility &&) = delete;
@@ -535,10 +536,10 @@ public:
 
   /// Returns the target JITDylib that these symbols are being materialized
   ///        into.
-  JITDylib &getTargetJITDylib() const { return *JD; }
+  JITDylib &getTargetJITDylib() const { return JD; }
 
   /// Returns the ExecutionSession for this instance.
-  ExecutionSession &getExecutionSession();
+  ExecutionSession &getExecutionSession() const;
 
   /// Returns the symbol flags map for this responsibility instance.
   /// Note: The returned flags may have transient flags (Lazy, Materializing)
@@ -640,15 +641,16 @@ public:
 private:
   /// Create a MaterializationResponsibility for the given JITDylib and
   ///        initial symbols.
-  MaterializationResponsibility(JITDylibSP JD, SymbolFlagsMap SymbolFlags,
+  MaterializationResponsibility(ResourceTrackerSP RT,
+                                SymbolFlagsMap SymbolFlags,
                                 SymbolStringPtr InitSymbol)
-      : JD(std::move(JD)), SymbolFlags(std::move(SymbolFlags)),
-        InitSymbol(std::move(InitSymbol)) {
-    assert(this->JD && "Cannot initialize with null JITDylib");
+      : JD(RT->getJITDylib()), RT(std::move(RT)),
+        SymbolFlags(std::move(SymbolFlags)), InitSymbol(std::move(InitSymbol)) {
     assert(!this->SymbolFlags.empty() && "Materializing nothing?");
   }
 
-  JITDylibSP JD;
+  JITDylib &JD;
+  ResourceTrackerSP RT;
   SymbolFlagsMap SymbolFlags;
   SymbolStringPtr InitSymbol;
 };
@@ -1151,7 +1153,6 @@ private:
 
   JITDylib(ExecutionSession &ES, std::string Name);
 
-  ResourceTrackerSP getTracker(MaterializationResponsibility &MR);
   std::pair<AsynchronousSymbolQuerySet, std::shared_ptr<SymbolDependenceMap>>
   removeTracker(ResourceTracker &RT);
 
@@ -1208,7 +1209,8 @@ private:
 
   // Map trackers to sets of symbols tracked.
   DenseMap<ResourceTracker *, SymbolNameVector> TrackerSymbols;
-  DenseMap<MaterializationResponsibility *, ResourceTracker *> MRTrackers;
+  DenseMap<ResourceTracker *, DenseSet<MaterializationResponsibility *>>
+      TrackerMRs;
 };
 
 /// Platforms set up standard symbols and mediate interactions between dynamic
@@ -1574,9 +1576,9 @@ private:
                                       SymbolStringPtr InitSymbol) {
     auto &JD = RT.getJITDylib();
     std::unique_ptr<MaterializationResponsibility> MR(
-        new MaterializationResponsibility(&JD, std::move(Symbols),
+        new MaterializationResponsibility(&RT, std::move(Symbols),
                                           std::move(InitSymbol)));
-    JD.MRTrackers[MR.get()] = &RT;
+    JD.TrackerMRs[&RT].insert(MR.get());
     return MR;
   }
 
@@ -1660,18 +1662,17 @@ private:
       JITDispatchHandlers;
 };
 
-inline ExecutionSession &MaterializationResponsibility::getExecutionSession() {
-  return JD->getExecutionSession();
+inline ExecutionSession &
+MaterializationResponsibility::getExecutionSession() const {
+  return JD.getExecutionSession();
 }
 
 template <typename Func>
 Error MaterializationResponsibility::withResourceKeyDo(Func &&F) const {
-  return JD->getExecutionSession().runSessionLocked([&]() -> Error {
-    auto I = JD->MRTrackers.find(this);
-    assert(I != JD->MRTrackers.end() && "No tracker for this MR");
-    if (I->second->isDefunct())
-      return make_error<ResourceTrackerDefunct>(I->second);
-    F(I->second->getKeyUnsafe());
+  return JD.getExecutionSession().runSessionLocked([&]() -> Error {
+    if (RT->isDefunct())
+      return make_error<ResourceTrackerDefunct>(RT);
+    F(RT->getKeyUnsafe());
     return Error::success();
   });
 }
@@ -1800,50 +1801,50 @@ private:
 // ---------------------------------------------
 
 inline MaterializationResponsibility::~MaterializationResponsibility() {
-  JD->getExecutionSession().OL_destroyMaterializationResponsibility(*this);
+  getExecutionSession().OL_destroyMaterializationResponsibility(*this);
 }
 
 inline SymbolNameSet MaterializationResponsibility::getRequestedSymbols() const {
-  return JD->getExecutionSession().OL_getRequestedSymbols(*this);
+  return getExecutionSession().OL_getRequestedSymbols(*this);
 }
 
 inline Error MaterializationResponsibility::notifyResolved(
     const SymbolMap &Symbols) {
-  return JD->getExecutionSession().OL_notifyResolved(*this, Symbols);
+  return getExecutionSession().OL_notifyResolved(*this, Symbols);
 }
 
 inline Error MaterializationResponsibility::notifyEmitted() {
-  return JD->getExecutionSession().OL_notifyEmitted(*this);
+  return getExecutionSession().OL_notifyEmitted(*this);
 }
 
 inline Error MaterializationResponsibility::defineMaterializing(
     SymbolFlagsMap SymbolFlags) {
-  return JD->getExecutionSession().OL_defineMaterializing(
-      *this, std::move(SymbolFlags));
+  return getExecutionSession().OL_defineMaterializing(*this,
+                                                      std::move(SymbolFlags));
 }
 
 inline void MaterializationResponsibility::failMaterialization() {
-  JD->getExecutionSession().OL_notifyFailed(*this);
+  getExecutionSession().OL_notifyFailed(*this);
 }
 
 inline Error MaterializationResponsibility::replace(
     std::unique_ptr<MaterializationUnit> MU) {
-  return JD->getExecutionSession().OL_replace(*this, std::move(MU));
+  return getExecutionSession().OL_replace(*this, std::move(MU));
 }
 
 inline Expected<std::unique_ptr<MaterializationResponsibility>>
 MaterializationResponsibility::delegate(const SymbolNameSet &Symbols) {
-  return JD->getExecutionSession().OL_delegate(*this, Symbols);
+  return getExecutionSession().OL_delegate(*this, Symbols);
 }
 
 inline void MaterializationResponsibility::addDependencies(
     const SymbolStringPtr &Name, const SymbolDependenceMap &Dependencies) {
-  JD->getExecutionSession().OL_addDependencies(*this, Name, Dependencies);
+  getExecutionSession().OL_addDependencies(*this, Name, Dependencies);
 }
 
 inline void MaterializationResponsibility::addDependenciesForAll(
     const SymbolDependenceMap &Dependencies) {
-  JD->getExecutionSession().OL_addDependenciesForAll(*this, Dependencies);
+  getExecutionSession().OL_addDependenciesForAll(*this, Dependencies);
 }
 
 } // End namespace orc
index 6b24d64..afd116b 100644 (file)
@@ -708,10 +708,8 @@ Error JITDylib::replace(MaterializationResponsibility &FromMR,
 
   auto Err =
       ES.runSessionLocked([&, this]() -> Error {
-        auto RT = getTracker(FromMR);
-
-        if (RT->isDefunct())
-          return make_error<ResourceTrackerDefunct>(std::move(RT));
+        if (FromMR.RT->isDefunct())
+          return make_error<ResourceTrackerDefunct>(std::move(FromMR.RT));
 
 #ifndef NDEBUG
         for (auto &KV : MU->getSymbols()) {
@@ -735,7 +733,8 @@ Error JITDylib::replace(MaterializationResponsibility &FromMR,
           if (MII != MaterializingInfos.end()) {
             if (MII->second.hasQueriesPending()) {
               MustRunMR = ES.createMaterializationResponsibility(
-                  *RT, std::move(MU->SymbolFlags), std::move(MU->InitSymbol));
+                  *FromMR.RT, std::move(MU->SymbolFlags),
+                  std::move(MU->InitSymbol));
               MustRunMU = std::move(MU);
               return Error::success();
             }
@@ -743,10 +742,8 @@ Error JITDylib::replace(MaterializationResponsibility &FromMR,
         }
 
         // Otherwise, make MU responsible for all the symbols.
-        auto RTI = MRTrackers.find(&FromMR);
-        assert(RTI != MRTrackers.end() && "No tracker for FromMR");
-        auto UMI =
-            std::make_shared<UnmaterializedInfo>(std::move(MU), RTI->second);
+        auto UMI = std::make_shared<UnmaterializedInfo>(std::move(MU),
+                                                        FromMR.RT.get());
         for (auto &KV : UMI->MU->getSymbols()) {
           auto SymI = Symbols.find(KV.first);
           assert(SymI->second.getState() == SymbolState::Materializing &&
@@ -787,13 +784,11 @@ JITDylib::delegate(MaterializationResponsibility &FromMR,
 
   return ES.runSessionLocked(
       [&]() -> Expected<std::unique_ptr<MaterializationResponsibility>> {
-        auto RT = getTracker(FromMR);
-
-        if (RT->isDefunct())
-          return make_error<ResourceTrackerDefunct>(std::move(RT));
+        if (FromMR.RT->isDefunct())
+          return make_error<ResourceTrackerDefunct>(std::move(FromMR.RT));
 
         return ES.createMaterializationResponsibility(
-            *RT, std::move(SymbolFlags), std::move(InitSymbol));
+            *FromMR.RT, std::move(SymbolFlags), std::move(InitSymbol));
       });
 }
 
@@ -903,10 +898,8 @@ Error JITDylib::resolve(MaterializationResponsibility &MR,
   AsynchronousSymbolQuerySet CompletedQueries;
 
   if (auto Err = ES.runSessionLocked([&, this]() -> Error {
-        auto RTI = MRTrackers.find(&MR);
-        assert(RTI != MRTrackers.end() && "No resource tracker for MR?");
-        if (RTI->second->isDefunct())
-          return make_error<ResourceTrackerDefunct>(RTI->second);
+        if (MR.RT->isDefunct())
+          return make_error<ResourceTrackerDefunct>(MR.RT);
 
         struct WorklistEntry {
           SymbolTable::iterator SymI;
@@ -1001,10 +994,8 @@ Error JITDylib::emit(MaterializationResponsibility &MR,
   DenseMap<JITDylib *, SymbolNameVector> ReadySymbols;
 
   if (auto Err = ES.runSessionLocked([&, this]() -> Error {
-        auto RTI = MRTrackers.find(&MR);
-        assert(RTI != MRTrackers.end() && "No resource tracker for MR?");
-        if (RTI->second->isDefunct())
-          return make_error<ResourceTrackerDefunct>(RTI->second);
+        if (MR.RT->isDefunct())
+          return make_error<ResourceTrackerDefunct>(MR.RT);
 
         SymbolNameSet SymbolsInErrorState;
         std::vector<SymbolTable::iterator> Worklist;
@@ -1149,9 +1140,12 @@ Error JITDylib::emit(MaterializationResponsibility &MR,
 void JITDylib::unlinkMaterializationResponsibility(
     MaterializationResponsibility &MR) {
   ES.runSessionLocked([&]() {
-    auto I = MRTrackers.find(&MR);
-    assert(I != MRTrackers.end() && "MaterializationResponsibility not linked");
-    MRTrackers.erase(I);
+    auto I = TrackerMRs.find(MR.RT.get());
+    assert(I != TrackerMRs.end() && "No MRs in TrackerMRs list for RT");
+    assert(I->second.count(&MR) && "MR not in TrackerMRs list for RT");
+    I->second.erase(&MR);
+    if (I->second.empty())
+      TrackerMRs.erase(MR.RT.get());
   });
 }
 
@@ -1454,13 +1448,6 @@ JITDylib::JITDylib(ExecutionSession &ES, std::string Name)
   LinkOrder.push_back({this, JITDylibLookupFlags::MatchAllSymbols});
 }
 
-ResourceTrackerSP JITDylib::getTracker(MaterializationResponsibility &MR) {
-  auto I = MRTrackers.find(&MR);
-  assert(I != MRTrackers.end() && "MR is not linked");
-  assert(I->second && "Linked tracker is null");
-  return I->second;
-}
-
 std::pair<JITDylib::AsynchronousSymbolQuerySet,
           std::shared_ptr<SymbolDependenceMap>>
 JITDylib::removeTracker(ResourceTracker &RT) {
@@ -1536,9 +1523,22 @@ void JITDylib::transferTracker(ResourceTracker &DstRT, ResourceTracker &SrcRT) {
   }
 
   // Update trackers for any active materialization responsibilities.
-  for (auto &KV : MRTrackers) {
-    if (KV.second == &SrcRT)
-      KV.second = &DstRT;
+  {
+    auto I = TrackerMRs.find(&SrcRT);
+    if (I != TrackerMRs.end()) {
+      auto &SrcMRs = I->second;
+      auto &DstMRs = TrackerMRs[&DstRT];
+      for (auto *MR : SrcMRs)
+        MR->RT = &DstRT;
+      if (DstMRs.empty())
+        DstMRs = std::move(SrcMRs);
+      else
+        for (auto *MR : SrcMRs)
+          DstMRs.insert(MR);
+      // Erase SrcRT entry in TrackerMRs. Use &SrcRT key rather than iterator I
+      // for this, since I may have been invalidated by 'TrackerMRs[&DstRT]'.
+      TrackerMRs.erase(&SrcRT);
+    }
   }
 
   // If we're transfering to the default tracker we just need to delete the
@@ -2635,11 +2635,9 @@ void ExecutionSession::OL_completeLookup(
                  << " MUs.\n";
         });
         for (auto &UMI : KV.second) {
-          std::unique_ptr<MaterializationResponsibility> MR(
-              new MaterializationResponsibility(
-                  &JD, std::move(UMI->MU->SymbolFlags),
-                  std::move(UMI->MU->InitSymbol)));
-          JD.MRTrackers[MR.get()] = UMI->RT;
+          auto MR = createMaterializationResponsibility(
+              *UMI->RT, std::move(UMI->MU->SymbolFlags),
+              std::move(UMI->MU->InitSymbol));
           OutstandingMUs.push_back(
               std::make_pair(std::move(UMI->MU), std::move(MR)));
         }
@@ -2757,18 +2755,18 @@ void ExecutionSession::OL_destroyMaterializationResponsibility(
 
   assert(MR.SymbolFlags.empty() &&
          "All symbols should have been explicitly materialized or failed");
-  MR.JD->unlinkMaterializationResponsibility(MR);
+  MR.JD.unlinkMaterializationResponsibility(MR);
 }
 
 SymbolNameSet ExecutionSession::OL_getRequestedSymbols(
     const MaterializationResponsibility &MR) {
-  return MR.JD->getRequestedSymbols(MR.SymbolFlags);
+  return MR.JD.getRequestedSymbols(MR.SymbolFlags);
 }
 
 Error ExecutionSession::OL_notifyResolved(MaterializationResponsibility &MR,
                                           const SymbolMap &Symbols) {
   LLVM_DEBUG({
-    dbgs() << "In " << MR.JD->getName() << " resolving " << Symbols << "\n";
+    dbgs() << "In " << MR.JD.getName() << " resolving " << Symbols << "\n";
   });
 #ifndef NDEBUG
   for (auto &KV : Symbols) {
@@ -2783,15 +2781,16 @@ Error ExecutionSession::OL_notifyResolved(MaterializationResponsibility &MR,
   }
 #endif
 
-  return MR.JD->resolve(MR, Symbols);
+  return MR.JD.resolve(MR, Symbols);
 }
 
 Error ExecutionSession::OL_notifyEmitted(MaterializationResponsibility &MR) {
   LLVM_DEBUG({
-    dbgs() << "In " << MR.JD->getName() << " emitting " << MR.SymbolFlags << "\n";
+    dbgs() << "In " << MR.JD.getName() << " emitting " << MR.SymbolFlags
+           << "\n";
   });
 
-  if (auto Err = MR.JD->emit(MR, MR.SymbolFlags))
+  if (auto Err = MR.JD.emit(MR, MR.SymbolFlags))
     return Err;
 
   MR.SymbolFlags.clear();
@@ -2802,10 +2801,11 @@ Error ExecutionSession::OL_defineMaterializing(
     MaterializationResponsibility &MR, SymbolFlagsMap NewSymbolFlags) {
 
   LLVM_DEBUG({
-    dbgs() << "In " << MR.JD->getName() << " defining materializing symbols "
+    dbgs() << "In " << MR.JD.getName() << " defining materializing symbols "
            << NewSymbolFlags << "\n";
   });
-  if (auto AcceptedDefs = MR.JD->defineMaterializing(std::move(NewSymbolFlags))) {
+  if (auto AcceptedDefs =
+          MR.JD.defineMaterializing(std::move(NewSymbolFlags))) {
     // Add all newly accepted symbols to this responsibility object.
     for (auto &KV : *AcceptedDefs)
       MR.SymbolFlags.insert(KV);
@@ -2817,14 +2817,14 @@ Error ExecutionSession::OL_defineMaterializing(
 void ExecutionSession::OL_notifyFailed(MaterializationResponsibility &MR) {
 
   LLVM_DEBUG({
-    dbgs() << "In " << MR.JD->getName() << " failing materialization for "
+    dbgs() << "In " << MR.JD.getName() << " failing materialization for "
            << MR.SymbolFlags << "\n";
   });
 
   JITDylib::FailedSymbolsWorklist Worklist;
 
   for (auto &KV : MR.SymbolFlags)
-    Worklist.push_back(std::make_pair(MR.JD.get(), KV.first));
+    Worklist.push_back(std::make_pair(&MR.JD, KV.first));
   MR.SymbolFlags.clear();
 
   if (Worklist.empty())
@@ -2834,9 +2834,8 @@ void ExecutionSession::OL_notifyFailed(MaterializationResponsibility &MR) {
   std::shared_ptr<SymbolDependenceMap> FailedSymbols;
 
   runSessionLocked([&]() {
-    auto RTI = MR.JD->MRTrackers.find(&MR);
-    assert(RTI != MR.JD->MRTrackers.end() && "No tracker for this");
-    if (RTI->second->isDefunct())
+    // If the tracker is defunct then there's nothing to do here.
+    if (MR.RT->isDefunct())
       return;
 
     std::tie(FailedQueries, FailedSymbols) =
@@ -2858,12 +2857,12 @@ Error ExecutionSession::OL_replace(MaterializationResponsibility &MR,
   if (MU->getInitializerSymbol() == MR.InitSymbol)
     MR.InitSymbol = nullptr;
 
-  LLVM_DEBUG(MR.JD->getExecutionSession().runSessionLocked([&]() {
-    dbgs() << "In " << MR.JD->getName() << " replacing symbols with " << *MU
+  LLVM_DEBUG(MR.JD.getExecutionSession().runSessionLocked([&]() {
+    dbgs() << "In " << MR.JD.getName() << " replacing symbols with " << *MU
            << "\n";
   }););
 
-  return MR.JD->replace(MR, std::move(MU));
+  return MR.JD.replace(MR, std::move(MU));
 }
 
 Expected<std::unique_ptr<MaterializationResponsibility>>
@@ -2886,8 +2885,8 @@ ExecutionSession::OL_delegate(MaterializationResponsibility &MR,
     MR.SymbolFlags.erase(I);
   }
 
-  return MR.JD->delegate(MR, std::move(DelegatedFlags),
-                         std::move(DelegatedInitSymbol));
+  return MR.JD.delegate(MR, std::move(DelegatedFlags),
+                        std::move(DelegatedInitSymbol));
 }
 
 void ExecutionSession::OL_addDependencies(
@@ -2899,7 +2898,7 @@ void ExecutionSession::OL_addDependencies(
   });
   assert(MR.SymbolFlags.count(Name) &&
          "Symbol not covered by this MaterializationResponsibility instance");
-  MR.JD->addDependencies(Name, Dependencies);
+  MR.JD.addDependencies(Name, Dependencies);
 }
 
 void ExecutionSession::OL_addDependenciesForAll(
@@ -2910,7 +2909,7 @@ void ExecutionSession::OL_addDependenciesForAll(
            << Dependencies << "\n";
   });
   for (auto &KV : MR.SymbolFlags)
-    MR.JD->addDependencies(KV.first, Dependencies);
+    MR.JD.addDependencies(KV.first, Dependencies);
 }
 
 #ifndef NDEBUG