From 7f9a89f9a2cc55dbfc315aa11416fe3609918199 Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Sun, 9 May 2021 11:20:54 -0700 Subject: [PATCH] [ORC] Use the new dispatchTask API to run query callbacks. Dispatching query callbacks, rather than running them on the current thread, will allow them to be distributed across multiple threads. --- llvm/include/llvm/ExecutionEngine/Orc/Core.h | 7 ++---- llvm/lib/ExecutionEngine/Orc/Core.cpp | 29 +++++++++++++++++----- .../unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp | 28 ++++++++++----------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Core.h b/llvm/include/llvm/ExecutionEngine/Orc/Core.h index c37361f..f8dc039 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/Core.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/Core.h @@ -819,13 +819,10 @@ public: /// resolved. bool isComplete() const { return OutstandingSymbolsCount == 0; } - /// Call the NotifyComplete callback. - /// - /// This should only be called if all symbols covered by the query have - /// reached the specified state. - void handleComplete(); private: + void handleComplete(ExecutionSession &ES); + SymbolState getRequiredState() { return RequiredState; } void addQueryDependence(JITDylib &JD, SymbolStringPtr Name); diff --git a/llvm/lib/ExecutionEngine/Orc/Core.cpp b/llvm/lib/ExecutionEngine/Orc/Core.cpp index 4300a0b..270ef7c 100644 --- a/llvm/lib/ExecutionEngine/Orc/Core.cpp +++ b/llvm/lib/ExecutionEngine/Orc/Core.cpp @@ -170,13 +170,30 @@ void AsynchronousSymbolQuery::notifySymbolMetRequiredState( --OutstandingSymbolsCount; } -void AsynchronousSymbolQuery::handleComplete() { +void AsynchronousSymbolQuery::handleComplete(ExecutionSession &ES) { assert(OutstandingSymbolsCount == 0 && "Symbols remain, handleComplete called prematurely"); - auto TmpNotifyComplete = std::move(NotifyComplete); + class RunQueryCompleteTask : public Task { + public: + RunQueryCompleteTask(SymbolMap ResolvedSymbols, + SymbolsResolvedCallback NotifyComplete) + : ResolvedSymbols(std::move(ResolvedSymbols)), + NotifyComplete(std::move(NotifyComplete)) {} + void printDescription(raw_ostream &OS) override { + OS << "Execute query complete callback for " << ResolvedSymbols; + } + void run() override { NotifyComplete(std::move(ResolvedSymbols)); } + + private: + SymbolMap ResolvedSymbols; + SymbolsResolvedCallback NotifyComplete; + }; + + auto T = std::make_unique(std::move(ResolvedSymbols), + std::move(NotifyComplete)); NotifyComplete = SymbolsResolvedCallback(); - TmpNotifyComplete(std::move(ResolvedSymbols)); + ES.dispatchTask(std::move(T)); } void AsynchronousSymbolQuery::handleFailed(Error Err) { @@ -969,7 +986,7 @@ Error JITDylib::resolve(MaterializationResponsibility &MR, // Otherwise notify all the completed queries. for (auto &Q : CompletedQueries) { assert(Q->isComplete() && "Q not completed"); - Q->handleComplete(); + Q->handleComplete(ES); } return Error::success(); @@ -1120,7 +1137,7 @@ Error JITDylib::emit(MaterializationResponsibility &MR, // Otherwise notify all the completed queries. for (auto &Q : CompletedQueries) { assert(Q->isComplete() && "Q is not complete"); - Q->handleComplete(); + Q->handleComplete(ES); } return Error::success(); @@ -2541,7 +2558,7 @@ void ExecutionSession::OL_completeLookup( if (QueryComplete) { LLVM_DEBUG(dbgs() << "Completing query\n"); - Q->handleComplete(); + Q->handleComplete(*this); } dispatchOutstandingMUs(); diff --git a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp index 5128cc9..8935ea4 100644 --- a/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/CoreAPIsTest.cpp @@ -1019,12 +1019,11 @@ TEST_F(CoreAPIsStandardTest, TestBasicWeakSymbolMaterialization) { TEST_F(CoreAPIsStandardTest, DefineMaterializingSymbol) { bool ExpectNoMoreMaterialization = false; - ES.setDispatchTask( - [&](std::unique_ptr T) { - if (ExpectNoMoreMaterialization) - ADD_FAILURE() << "Unexpected materialization"; - T->run(); - }); + ES.setDispatchTask([&](std::unique_ptr T) { + if (ExpectNoMoreMaterialization && isa(*T)) + ADD_FAILURE() << "Unexpected materialization"; + T->run(); + }); auto MU = std::make_unique( SymbolFlagsMap({{Foo, FooSym.getFlags()}}), @@ -1250,14 +1249,11 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithUnthreadedMaterialization) { TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) { #if LLVM_ENABLE_THREADS - std::thread MaterializationThread; - ES.setDispatchTask( - [&](std::unique_ptr T) { - MaterializationThread = - std::thread([T = std::move(T)]() mutable { - T->run(); - }); - }); + std::vector WorkThreads; + ES.setDispatchTask([&](std::unique_ptr T) { + WorkThreads.push_back( + std::thread([T = std::move(T)]() mutable { T->run(); })); + }); cantFail(JD.define(absoluteSymbols({{Foo, FooSym}}))); @@ -1267,7 +1263,9 @@ TEST_F(CoreAPIsStandardTest, TestLookupWithThreadedMaterialization) { << "lookup returned an incorrect address"; EXPECT_EQ(FooLookupResult.getFlags(), FooSym.getFlags()) << "lookup returned incorrect flags"; - MaterializationThread.join(); + + for (auto &WT : WorkThreads) + WT.join(); #endif } -- 2.7.4