[ORC] Simplify VSO::lookupFlags to return the flags map.
authorLang Hames <lhames@gmail.com>
Fri, 20 Jul 2018 18:31:52 +0000 (18:31 +0000)
committerLang Hames <lhames@gmail.com>
Fri, 20 Jul 2018 18:31:52 +0000 (18:31 +0000)
This discards the unresolved symbols set and returns the flags map directly
(rather than mutating it via the first argument).

The unresolved symbols result made it easy to chain lookupFlags calls, but such
chaining should be rare to non-existant (especially now that symbol resolvers
are being deprecated) so the simpler method signature is preferable.

llvm-svn: 337594

13 files changed:
llvm/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h
llvm/include/llvm/ExecutionEngine/Orc/Core.h
llvm/include/llvm/ExecutionEngine/Orc/Legacy.h
llvm/include/llvm/ExecutionEngine/Orc/NullResolver.h
llvm/lib/ExecutionEngine/Orc/Core.cpp
llvm/lib/ExecutionEngine/Orc/Legacy.cpp
llvm/lib/ExecutionEngine/Orc/NullResolver.cpp
llvm/lib/ExecutionEngine/Orc/OrcCBindingsStack.h
llvm/lib/ExecutionEngine/Orc/OrcMCJITReplacement.h
llvm/lib/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.cpp
llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp
llvm/unittests/ExecutionEngine/Orc/LegacyAPIInteropTest.cpp
llvm/unittests/ExecutionEngine/Orc/RTDyldObjectLinkingLayerTest.cpp

index 3a720952df786472b23e91907188c8a1f3d369ad..8bd21a0e3dd6786cb3e87992d9e6aa5dc97b1752 100644 (file)
@@ -499,20 +499,29 @@ private:
     };
 
     auto GVsResolver = createSymbolResolver(
-        [&LD, LegacyLookup](SymbolFlagsMap &SymbolFlags,
-                            const SymbolNameSet &Symbols) {
-          auto NotFoundViaLegacyLookup =
-              lookupFlagsWithLegacyFn(SymbolFlags, Symbols, LegacyLookup);
+        [&LD, LegacyLookup](const SymbolNameSet &Symbols) {
+          auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup);
 
-          if (!NotFoundViaLegacyLookup) {
-            logAllUnhandledErrors(NotFoundViaLegacyLookup.takeError(), errs(),
+          if (!SymbolFlags) {
+            logAllUnhandledErrors(SymbolFlags.takeError(), errs(),
                                   "CODLayer/GVsResolver flags lookup failed: ");
-            SymbolFlags.clear();
-            return SymbolNameSet();
+            return SymbolFlagsMap();
           }
 
-          return LD.BackingResolver->lookupFlags(SymbolFlags,
-                                                 *NotFoundViaLegacyLookup);
+          if (SymbolFlags->size() == Symbols.size())
+            return *SymbolFlags;
+
+          SymbolNameSet NotFoundViaLegacyLookup;
+          for (auto &S : Symbols)
+            if (!SymbolFlags->count(S))
+              NotFoundViaLegacyLookup.insert(S);
+          auto SymbolFlags2 =
+              LD.BackingResolver->lookupFlags(NotFoundViaLegacyLookup);
+
+          for (auto &KV : SymbolFlags2)
+            (*SymbolFlags)[KV.first] = std::move(KV.second);
+
+          return *SymbolFlags;
         },
         [this, &LD,
          LegacyLookup](std::shared_ptr<AsynchronousSymbolQuery> Query,
@@ -659,18 +668,29 @@ private:
 
     // Create memory manager and symbol resolver.
     auto Resolver = createSymbolResolver(
-        [&LD, LegacyLookup](SymbolFlagsMap &SymbolFlags,
-                            const SymbolNameSet &Symbols) {
-          auto NotFoundViaLegacyLookup =
-              lookupFlagsWithLegacyFn(SymbolFlags, Symbols, LegacyLookup);
-          if (!NotFoundViaLegacyLookup) {
-            logAllUnhandledErrors(NotFoundViaLegacyLookup.takeError(), errs(),
+        [&LD, LegacyLookup](const SymbolNameSet &Symbols) {
+          auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup);
+          if (!SymbolFlags) {
+            logAllUnhandledErrors(SymbolFlags.takeError(), errs(),
                                   "CODLayer/SubResolver flags lookup failed: ");
-            SymbolFlags.clear();
-            return SymbolNameSet();
+            return SymbolFlagsMap();
           }
-          return LD.BackingResolver->lookupFlags(SymbolFlags,
-                                                 *NotFoundViaLegacyLookup);
+
+          if (SymbolFlags->size() == Symbols.size())
+            return *SymbolFlags;
+
+          SymbolNameSet NotFoundViaLegacyLookup;
+          for (auto &S : Symbols)
+            if (!SymbolFlags->count(S))
+              NotFoundViaLegacyLookup.insert(S);
+
+          auto SymbolFlags2 =
+              LD.BackingResolver->lookupFlags(NotFoundViaLegacyLookup);
+
+          for (auto &KV : SymbolFlags2)
+            (*SymbolFlags)[KV.first] = std::move(KV.second);
+
+          return *SymbolFlags;
         },
         [this, &LD, LegacyLookup](std::shared_ptr<AsynchronousSymbolQuery> Q,
                                   SymbolNameSet Symbols) {
index 5894525e0eba424b9b6324b4dcbedb68d375ff5d..5060f84c2413d4e7efd22cec0cc8a767c3114991 100644 (file)
@@ -577,7 +577,7 @@ public:
 
   /// Search the given VSO for the symbols in Symbols. If found, store
   ///        the flags for each symbol in Flags. Returns any unresolved symbols.
-  SymbolNameSet lookupFlags(SymbolFlagsMap &Flags, const SymbolNameSet &Names);
+  SymbolFlagsMap lookupFlags(const SymbolNameSet &Names);
 
   /// Search the given VSOs in order for the symbols in Symbols. Results
   ///        (once they become available) will be returned via the given Query.
index d5ece3e5a7f3fd3424bcddb7c2be0a7a8d011e43..e97f98edcdfabc2e9ba084cb8bcefc729d202e1e 100644 (file)
@@ -33,8 +33,7 @@ public:
 
   /// Returns the flags for each symbol in Symbols that can be found,
   ///        along with the set of symbol that could not be found.
-  virtual SymbolNameSet lookupFlags(SymbolFlagsMap &Flags,
-                                    const SymbolNameSet &Symbols) = 0;
+  virtual SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) = 0;
 
   /// For each symbol in Symbols that can be found, assigns that symbols
   ///        value in Query. Returns the set of symbols that could not be found.
@@ -55,9 +54,8 @@ public:
       : LookupFlags(std::forward<LookupFlagsFnRef>(LookupFlags)),
         Lookup(std::forward<LookupFnRef>(Lookup)) {}
 
-  SymbolNameSet lookupFlags(SymbolFlagsMap &Flags,
-                            const SymbolNameSet &Symbols) final {
-    return LookupFlags(Flags, Symbols);
+  SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) final {
+    return LookupFlags(Symbols);
   }
 
   SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query,
@@ -111,21 +109,18 @@ private:
 ///
 /// Useful for implementing lookupFlags bodies that query legacy resolvers.
 template <typename FindSymbolFn>
-Expected<SymbolNameSet> lookupFlagsWithLegacyFn(SymbolFlagsMap &SymbolFlags,
-                                                const SymbolNameSet &Symbols,
-                                                FindSymbolFn FindSymbol) {
-  SymbolNameSet SymbolsNotFound;
+Expected<SymbolFlagsMap> lookupFlagsWithLegacyFn(const SymbolNameSet &Symbols,
+                                                 FindSymbolFn FindSymbol) {
+  SymbolFlagsMap SymbolFlags;
 
   for (auto &S : Symbols) {
     if (JITSymbol Sym = FindSymbol(*S))
       SymbolFlags[S] = Sym.getFlags();
     else if (auto Err = Sym.takeError())
       return std::move(Err);
-    else
-      SymbolsNotFound.insert(S);
   }
 
-  return SymbolsNotFound;
+  return SymbolFlags;
 }
 
 /// Use the given legacy-style FindSymbol function (i.e. a function that
@@ -182,14 +177,12 @@ public:
       : ES(ES), LegacyLookup(std::move(LegacyLookup)),
         ReportError(std::move(ReportError)) {}
 
-  SymbolNameSet lookupFlags(SymbolFlagsMap &Flags,
-                            const SymbolNameSet &Symbols) final {
-    if (auto RemainingSymbols =
-            lookupFlagsWithLegacyFn(Flags, Symbols, LegacyLookup))
-      return std::move(*RemainingSymbols);
+  SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) final {
+    if (auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup))
+      return std::move(*SymbolFlags);
     else {
-      ReportError(RemainingSymbols.takeError());
-      return Symbols;
+      ReportError(SymbolFlags.takeError());
+      return SymbolFlagsMap();
     }
   }
 
index bfb9931df143a3b8b05071e257fed5cedf649a4f..3dd3cfe05b8d1e026b35536130b83a047addbfa5 100644 (file)
@@ -23,8 +23,7 @@ namespace orc {
 
 class NullResolver : public SymbolResolver {
 public:
-  SymbolNameSet lookupFlags(SymbolFlagsMap &Flags,
-                            const SymbolNameSet &Symbols) override;
+  SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) override;
 
   SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query,
                        SymbolNameSet Symbols) override;
index 9275bb84d01afc728af394c14e8461293dd096e7..65f1178228848ab14de4eaf29210f4a68ef85c4b 100644 (file)
@@ -522,11 +522,14 @@ ReExportsMaterializationUnit::extractFlags(const SymbolAliasMap &Aliases) {
 
 Expected<SymbolAliasMap>
 buildSimpleReexportsAliasMap(VSO &SourceV, const SymbolNameSet &Symbols) {
-  SymbolFlagsMap Flags;
-  auto Unresolved = SourceV.lookupFlags(Flags, Symbols);
+  auto Flags = SourceV.lookupFlags(Symbols);
 
-  if (!Unresolved.empty())
+  if (Flags.size() != Symbols.size()) {
+    SymbolNameSet Unresolved = Symbols;
+    for (auto &KV : Flags)
+      Unresolved.erase(KV.first);
     return make_error<SymbolsNotFound>(std::move(Unresolved));
+  }
 
   SymbolAliasMap Result;
   for (auto &Name : Symbols) {
@@ -900,22 +903,20 @@ void VSO::removeFromSearchOrder(VSO &V) {
   });
 }
 
-SymbolNameSet VSO::lookupFlags(SymbolFlagsMap &Flags,
-                               const SymbolNameSet &Names) {
+SymbolFlagsMap VSO::lookupFlags(const SymbolNameSet &Names) {
   return ES.runSessionLocked([&, this]() {
-    auto Unresolved = lookupFlagsImpl(Flags, Names);
+    SymbolFlagsMap Result;
+    auto Unresolved = lookupFlagsImpl(Result, Names);
     if (FallbackDefinitionGenerator && !Unresolved.empty()) {
       auto FallbackDefs = FallbackDefinitionGenerator(*this, Unresolved);
       if (!FallbackDefs.empty()) {
-        auto Unresolved2 = lookupFlagsImpl(Flags, FallbackDefs);
+        auto Unresolved2 = lookupFlagsImpl(Result, FallbackDefs);
         (void)Unresolved2;
         assert(Unresolved2.empty() &&
                "All fallback defs should have been found by lookupFlagsImpl");
-        for (auto &D : FallbackDefs)
-          Unresolved.erase(D);
       }
     };
-    return Unresolved;
+    return Result;
   });
 }
 
index 79525baf92e4d34b18ea6396b01efba7f5a5adb4..6fde6898a16c5d1acba901716a44867a0bb9fd00 100644 (file)
@@ -48,8 +48,7 @@ JITSymbolResolverAdapter::lookupFlags(const LookupSet &Symbols) {
   for (auto &S : Symbols)
     InternedSymbols.insert(ES.getSymbolStringPool().intern(S));
 
-  SymbolFlagsMap SymbolFlags;
-  R.lookupFlags(SymbolFlags, InternedSymbols);
+  SymbolFlagsMap SymbolFlags = R.lookupFlags(InternedSymbols);
   LookupFlagsResult Result;
   for (auto &KV : SymbolFlags) {
     ResolvedStrings.insert(KV.first);
index 872efea5fc8a71ec81ca4ba45895f10c5f99c030..3796e3d37bc223521cecb0cd0ded30b99a4ce0e1 100644 (file)
@@ -14,9 +14,8 @@
 namespace llvm {
 namespace orc {
 
-SymbolNameSet NullResolver::lookupFlags(SymbolFlagsMap &Flags,
-                                        const SymbolNameSet &Symbols) {
-  return Symbols;
+SymbolFlagsMap NullResolver::lookupFlags(const SymbolNameSet &Symbols) {
+  return SymbolFlagsMap();
 }
 
 SymbolNameSet
index aa63957236e041d2c92c545e02c76a3252ee7dda..6c44f4367ec0587abbc974b9ef4c422c46db4c86 100644 (file)
@@ -129,21 +129,20 @@ private:
         : Stack(Stack), ExternalResolver(std::move(ExternalResolver)),
           ExternalResolverCtx(std::move(ExternalResolverCtx)) {}
 
-    orc::SymbolNameSet lookupFlags(orc::SymbolFlagsMap &SymbolFlags,
-                                   const orc::SymbolNameSet &Symbols) override {
-      orc::SymbolNameSet SymbolsNotFound;
+    orc::SymbolFlagsMap
+    lookupFlags(const orc::SymbolNameSet &Symbols) override {
+      orc::SymbolFlagsMap SymbolFlags;
 
       for (auto &S : Symbols) {
         if (auto Sym = findSymbol(*S))
           SymbolFlags[S] = Sym.getFlags();
         else if (auto Err = Sym.takeError()) {
           Stack.reportError(std::move(Err));
-          return orc::SymbolNameSet();
-        } else
-          SymbolsNotFound.insert(S);
+          return orc::SymbolFlagsMap();
+        }
       }
 
-      return SymbolsNotFound;
+      return SymbolFlags;
     }
 
     orc::SymbolNameSet
index 922ec4762044022c2693d30eb8c0fbb7f91a7a68..ded53ac3106b4cf53891ef020258ee4b729a084b 100644 (file)
@@ -144,28 +144,26 @@ class OrcMCJITReplacement : public ExecutionEngine {
   public:
     LinkingORCResolver(OrcMCJITReplacement &M) : M(M) {}
 
-    SymbolNameSet lookupFlags(SymbolFlagsMap &SymbolFlags,
-                              const SymbolNameSet &Symbols) override {
-      SymbolNameSet UnresolvedSymbols;
+    SymbolFlagsMap lookupFlags(const SymbolNameSet &Symbols) override {
+      SymbolFlagsMap SymbolFlags;
 
       for (auto &S : Symbols) {
         if (auto Sym = M.findMangledSymbol(*S)) {
           SymbolFlags[S] = Sym.getFlags();
         } else if (auto Err = Sym.takeError()) {
           M.reportError(std::move(Err));
-          return SymbolNameSet();
+          return SymbolFlagsMap();
         } else {
           if (auto Sym2 = M.ClientResolver->findSymbolInLogicalDylib(*S)) {
             SymbolFlags[S] = Sym2.getFlags();
           } else if (auto Err = Sym2.takeError()) {
             M.reportError(std::move(Err));
-            return SymbolNameSet();
-          } else
-            UnresolvedSymbols.insert(S);
+            return SymbolFlagsMap();
+          }
         }
       }
 
-      return UnresolvedSymbols;
+      return SymbolFlags;
     }
 
     SymbolNameSet lookup(std::shared_ptr<AsynchronousSymbolQuery> Query,
index 8c53b4f58de23c3ad7bdd9f94ecda143a44cf12c..7cdc6b352d11c8416988bf52dde47f0684f11f73 100644 (file)
@@ -64,7 +64,7 @@ public:
         return;
 
       assert(VSOs.front() && "VSOList entry can not be null");
-      VSOs.front()->lookupFlags(InternedResult, InternedSymbols);
+      InternedResult = VSOs.front()->lookupFlags(InternedSymbols);
     });
 
     LookupFlagsResult Result;
index 9d4f9463372f1484aa6d7c909c1ef1128152285c..c0afbc6be06a84bffac0b2094d0077b10ce476df 100644 (file)
@@ -215,11 +215,8 @@ TEST_F(CoreAPIsStandardTest, LookupFlagsTest) {
 
   SymbolNameSet Names({Foo, Bar, Baz});
 
-  SymbolFlagsMap SymbolFlags;
-  auto SymbolsNotFound = V.lookupFlags(SymbolFlags, Names);
+  auto SymbolFlags = V.lookupFlags(Names);
 
-  EXPECT_EQ(SymbolsNotFound.size(), 1U) << "Expected one not-found symbol";
-  EXPECT_EQ(SymbolsNotFound.count(Baz), 1U) << "Expected Baz to be not-found";
   EXPECT_EQ(SymbolFlags.size(), 2U)
       << "Returned symbol flags contains unexpected results";
   EXPECT_EQ(SymbolFlags.count(Foo), 1U) << "Missing lookupFlags result for Foo";
index 51f86eacfd95a71c09cb3df05842a498b59f7cb0..596584b7117e3b08f489b5e1d48e57d4dff47102 100644 (file)
@@ -22,17 +22,14 @@ TEST_F(LegacyAPIsStandardTest, TestLambdaSymbolResolver) {
   cantFail(V.define(absoluteSymbols({{Foo, FooSym}, {Bar, BarSym}})));
 
   auto Resolver = createSymbolResolver(
-      [&](SymbolFlagsMap &SymbolFlags, const SymbolNameSet &Symbols) {
-        return V.lookupFlags(SymbolFlags, Symbols);
-      },
+      [&](const SymbolNameSet &Symbols) { return V.lookupFlags(Symbols); },
       [&](std::shared_ptr<AsynchronousSymbolQuery> Q, SymbolNameSet Symbols) {
         return V.lookup(std::move(Q), Symbols);
       });
 
   SymbolNameSet Symbols({Foo, Bar, Baz});
 
-  SymbolFlagsMap SymbolFlags;
-  SymbolNameSet SymbolsNotFound = Resolver->lookupFlags(SymbolFlags, Symbols);
+  SymbolFlagsMap SymbolFlags = Resolver->lookupFlags(Symbols);
 
   EXPECT_EQ(SymbolFlags.size(), 2U)
       << "lookupFlags returned the wrong number of results";
@@ -42,10 +39,6 @@ TEST_F(LegacyAPIsStandardTest, TestLambdaSymbolResolver) {
       << "Incorrect lookupFlags result for Foo";
   EXPECT_EQ(SymbolFlags[Bar], BarSym.getFlags())
       << "Incorrect lookupFlags result for Bar";
-  EXPECT_EQ(SymbolsNotFound.size(), 1U)
-      << "Expected one symbol not found in lookupFlags";
-  EXPECT_EQ(SymbolsNotFound.count(Baz), 1U)
-      << "Expected baz not to be found in lookupFlags";
 
   bool OnResolvedRun = false;
 
@@ -86,9 +79,8 @@ TEST(LegacyAPIInteropTest, QueryAgainstVSO) {
   JITEvaluatedSymbol FooSym(0xdeadbeef, JITSymbolFlags::Exported);
   cantFail(V.define(absoluteSymbols({{Foo, FooSym}})));
 
-  auto LookupFlags = [&](SymbolFlagsMap &SymbolFlags,
-                         const SymbolNameSet &Names) {
-    return V.lookupFlags(SymbolFlags, Names);
+  auto LookupFlags = [&](const SymbolNameSet &Names) {
+    return V.lookupFlags(Names);
   };
 
   auto Lookup = [&](std::shared_ptr<AsynchronousSymbolQuery> Query,
@@ -153,19 +145,14 @@ TEST(LegacyAPIInteropTset, LegacyLookupHelpersFn) {
 
   SymbolNameSet Symbols({Foo, Bar, Baz});
 
-  SymbolFlagsMap SymbolFlags;
-  auto SymbolsNotFound =
-      lookupFlagsWithLegacyFn(SymbolFlags, Symbols, LegacyLookup);
-
-  EXPECT_TRUE(!!SymbolsNotFound) << "lookupFlagsWithLegacy failed unexpectedly";
-  EXPECT_EQ(SymbolFlags.size(), 2U) << "Wrong number of flags returned";
-  EXPECT_EQ(SymbolFlags.count(Foo), 1U) << "Flags for foo missing";
-  EXPECT_EQ(SymbolFlags.count(Bar), 1U) << "Flags for foo missing";
-  EXPECT_EQ(SymbolFlags[Foo], FooFlags) << "Wrong flags for foo";
-  EXPECT_EQ(SymbolFlags[Bar], BarFlags) << "Wrong flags for foo";
-  EXPECT_EQ(SymbolsNotFound->size(), 1U) << "Expected one symbol not found";
-  EXPECT_EQ(SymbolsNotFound->count(Baz), 1U)
-      << "Expected symbol baz to be not found";
+  auto SymbolFlags = lookupFlagsWithLegacyFn(Symbols, LegacyLookup);
+
+  EXPECT_TRUE(!!SymbolFlags) << "Expected lookupFlagsWithLegacyFn to succeed";
+  EXPECT_EQ(SymbolFlags->size(), 2U) << "Wrong number of flags returned";
+  EXPECT_EQ(SymbolFlags->count(Foo), 1U) << "Flags for foo missing";
+  EXPECT_EQ(SymbolFlags->count(Bar), 1U) << "Flags for foo missing";
+  EXPECT_EQ((*SymbolFlags)[Foo], FooFlags) << "Wrong flags for foo";
+  EXPECT_EQ((*SymbolFlags)[Bar], BarFlags) << "Wrong flags for foo";
   EXPECT_FALSE(BarMaterialized)
       << "lookupFlags should not have materialized bar";
 
index 17f733c99d201feedb761d575068929437e9b1b0..420631c36ad2dd892dc72b26bd4564c07dda7544 100644 (file)
@@ -184,9 +184,8 @@ TEST_F(RTDyldObjectLinkingLayerExecutionTest, NoDuplicateFinalization) {
   };
 
   Resolvers[K2] = createSymbolResolver(
-      [&](SymbolFlagsMap &SymbolFlags, const SymbolNameSet &Symbols) {
-        return cantFail(
-            lookupFlagsWithLegacyFn(SymbolFlags, Symbols, LegacyLookup));
+      [&](const SymbolNameSet &Symbols) {
+        return cantFail(lookupFlagsWithLegacyFn(Symbols, LegacyLookup));
       },
       [&](std::shared_ptr<AsynchronousSymbolQuery> Query,
           const SymbolNameSet &Symbols) {