1 // Copyright (C) 2018 Intel Corporation
3 // SPDX-License-Identifier: Apache-2.0
6 #include "execution_engine/execution_engine.hpp"
7 #include "execution_engine/backend.hpp"
8 #include "execution_engine/executable.hpp"
12 #include "util/assert.hpp"
13 #include "util/range.hpp"
14 #include "util/checked_cast.hpp"
// Concrete, non-derivable (final) implementation of the Executable interface.
// It owns the BackendExecutable objects handed to it through addExec().
19 class ExecutableImpl final : public Executable
// Scratch `any` that the argument-less run()/runAsync() reset before each run;
// runAsync() then passes it to runAsync(util::any&).
21 util::any m_dummy_opaque;
// Takes ownership of a backend executable (must be non-null).
24 void addExec(std::unique_ptr<BackendExecutable>&& exec);
// Executable interface: run variants, with and without a caller-supplied opaque value.
26 virtual void run() override;
27 virtual void run(util::any &opaque) override;
// Executable interface: asynchronous variants and the matching wait().
29 virtual void runAsync() override;
30 virtual void runAsync(util::any &opaque) override;
31 virtual void wait() override;
// Set by addExec() when still null; run/runAsync/wait assert that it is non-null.
34 std::unique_ptr<BackendExecutable> mainExec;
// Further backend executables (presumably those added after mainExec is set).
35 std::vector<std::unique_ptr<BackendExecutable>> execs;
// Default overload accepting an opaque `any`. The parameter is unnamed, so
// this default implementation does not use it.
38 void BackendExecutable::run(util::any &)
40 // Default implementation calls run() (backward compatibility)
// Default asynchronous overload accepting an opaque `any`. The parameter is
// unnamed, so this default implementation does not use it.
44 void BackendExecutable::runAsync(util::any &)
46 // Default implementation calls runAsync() (backward compatibility)
// Default constructor of ExecutionEngine (takes no arguments).
50 ExecutionEngine::ExecutionEngine()
// Destructor of ExecutionEngine.
55 ExecutionEngine::~ExecutionEngine()
// Registers a callback in the pre-pass callback list.
// `callback` is taken by value, must not be null, and is moved into the list.
60 void ExecutionEngine::addPrePassCallback(ExecutionEngine::PassCallback callback)
62 ASSERT(nullptr != callback);
63 m_prePassCallbacks.callbacks.emplace_back(std::move(callback));
// Registers a callback in the post-pass callback list.
// `callback` is taken by value, must not be null, and is moved into the list.
66 void ExecutionEngine::addPostPassCallback(ExecutionEngine::PassCallback callback)
68 ASSERT(nullptr != callback);
69 m_postPassCallbacks.callbacks.emplace_back(std::move(callback));
// Adds a backend to the engine; ownership is transferred into m_backends.
// `backend` must not be null.
72 void ExecutionEngine::addBackend(std::unique_ptr<ExecutionBackend>&& backend)
74 ASSERT(nullptr != backend);
// Sanity check: the same backend must not already be present in the list.
75 ASSERT(m_backends.end() == std::find(m_backends.begin(), m_backends.end(), backend));
76 m_backends.emplace_back(std::move(backend));
// Gives every registered backend a chance to configure this engine: each one
// receives the same ExecutionEngineSetupContext bound to *this.
79 void ExecutionEngine::setupBackends()
81 ExecutionEngineSetupContext context(*this);
82 for (auto& b: m_backends)
84 b->setupExecutionEngine(context);
// Helper that installs a listener on a graph in its constructor and removes it
// again in its destructor, so the listener is set only while this object lives.
90 struct GraphListenerSetter final
93 IGraphListener* listener = nullptr;
// Installs `l` on `gr`; the graph must not have a listener yet.
94 GraphListenerSetter(Graph& gr, IGraphListener* l):
95 graph(gr), listener(l)
97 ASSERT(nullptr == graph.getListener());
98 graph.setListener(listener);
// Clears the graph listener; it must still be the one installed by the constructor.
100 ~GraphListenerSetter()
102 ASSERT(listener == graph.getListener());
103 graph.setListener(nullptr);
// Non-copyable: a copy would reset the graph listener a second time.
106 GraphListenerSetter(const GraphListenerSetter&) = delete;
107 GraphListenerSetter& operator=(const GraphListenerSetter&) = delete;
// Runs the pass pipeline on `graph`, then processes every lazy pass that was
// registered as an executable dependency.
111 void ExecutionEngine::runPasses(Graph& graph)
// Start from a clean lazy-pass state and attach the lazy passes' listener to
// the graph for the duration of this function.
113 m_lazyPasses.reset();
114 GraphListenerSetter setter(graph, m_lazyPasses.getListener());
115 passes::PassContext context{graph};
116 m_passManager.run(context);
// Each dependency is a non-empty lazy pass name that must resolve to a pass.
117 for (auto& str: m_executableDependencies)
119 ASSERT(!str.empty());
120 auto pass = m_lazyPasses.getPass(str);
121 ASSERT(nullptr != pass);
122 pass->process(context);
// Builds an Executable for `graph` by asking every registered backend for a
// backend executable and collecting the non-null results in one ExecutableImpl.
126 std::unique_ptr<Executable> ExecutionEngine::createExecutable(const Graph& graph)
128 std::unique_ptr<ExecutableImpl> ret;
129 for (auto& b : m_backends)
131 std::unique_ptr<BackendExecutable> bexec(b->createExecutable(graph));
// Backends may return null; only non-null executables are added.
132 if (nullptr != bexec)
// NOTE(review): presumably guarded so the ExecutableImpl is created only once,
// on the first non-null backend executable — confirm.
136 ret.reset(new ExecutableImpl);
138 ret->addExec(std::move(bexec));
// The move converts unique_ptr<ExecutableImpl> to the returned unique_ptr<Executable>.
142 return std::move(ret);
// Records a lazy pass, by name, as a dependency that runPasses() processes
// after the main pass pipeline. The name must be non-empty and must resolve
// to an existing lazy pass.
145 void ExecutionEngine::addExecutableDependency(const std::string& lazyPassName)
147 ASSERT(!lazyPassName.empty());
148 ASSERT(m_lazyPasses.getPass(lazyPassName) != nullptr);
149 m_executableDependencies.emplace(lazyPassName);
// Adds a pass stage with the given non-empty name to the pass manager.
152 void ExecutionEngine::addPassStage(const std::string& stageName)
154 ASSERT(!stageName.empty());
155 m_passManager.addStage(stageName);
// Adds a pass stage positioned relative to an existing stage `prevStage`;
// both names must be non-empty. The exact placement is decided by
// m_passManager.addStage (presumably right after `prevStage` — confirm).
158 void ExecutionEngine::addPassStage(const std::string& stageName, const std::string& prevStage)
160 ASSERT(!stageName.empty());
161 ASSERT(!prevStage.empty());
162 m_passManager.addStage(stageName, prevStage);
// Returns the pass manager's stages as a range, with PassMapper applied to
// each element through util::map.
165 ExecutionEngine::StagesRange ExecutionEngine::passStages() const
167 return util::map<PassMapper>(m_passManager.stages());
// Invokes all registered pre-pass callbacks with the given pass description
// and pass context.
170 void ExecutionEngine::prePass(const PassDesc& desc,
171 const passes::PassContext& context)
173 m_prePassCallbacks.call(desc, context);
// Invokes all registered post-pass callbacks with the given pass description
// and pass context.
176 void ExecutionEngine::postPass(const PassDesc& desc,
177 const passes::PassContext& context)
179 m_postPassCallbacks.call(desc, context);
// Takes ownership of a non-null backend executable.
182 void ExecutableImpl::addExec(std::unique_ptr<BackendExecutable>&& exec)
184 ASSERT(nullptr != exec);
// The first executable added becomes mainExec.
185 if (nullptr == mainExec)
187 mainExec = std::move(exec);
// NOTE(review): presumably the alternative branch — later executables are
// appended to execs — confirm.
191 execs.emplace_back(std::move(exec));
// Helper whose destructor calls cancel() on executables from the referenced
// vector, based on how many of them are currently counted in passedCount.
195 struct ExecExceptionHandler
// Counter maintained by the caller (incremented/decremented in run() and wait()).
197 size_t passedCount = 0;
// Vector of executables this handler acts on; held by reference, not owned.
198 std::vector<std::unique_ptr<BackendExecutable>> &handledVector;
199 ExecExceptionHandler(std::vector<std::unique_ptr<BackendExecutable>> &execs) : handledVector(execs) {}
200 ~ExecExceptionHandler()
202 ASSERT(handledVector.size() >= passedCount);
// `count` is the number of elements not covered by passedCount.
203 auto count = util::checked_cast<int>(handledVector.size() - passedCount);
// Iteration starts at the last element. NOTE(review): the loop bound and step
// are not shown here — presumably walks down to `count`; confirm.
204 for (auto i = util::checked_cast<int>(handledVector.size()) - 1;
208 handledVector[i]->cancel();
// Argument-less run: clears the internal dummy opaque value first.
// NOTE(review): presumably then forwards to run(m_dummy_opaque), mirroring
// runAsync() — confirm.
213 void ExecutableImpl::run()
215 // Since run() takes a modifiable `any`, reset it before the run
216 m_dummy_opaque = util::any();
// Runs with a caller-supplied opaque value, which is passed to mainExec->run().
// Requires that at least one executable was added (mainExec is non-null).
220 void ExecutableImpl::run(util::any &opaque)
222 ASSERT(nullptr != mainExec);
// The handler's destructor cancels executables according to passedCount.
223 ExecExceptionHandler handler(execs);
// First phase: walk the additional executables in reverse order, counting
// each one in the handler.
224 for (auto& e: util::toRangeReverse(execs))
227 handler.passedCount++;
230 mainExec->run(opaque);
// Second phase: walk the additional executables in forward order, removing
// each one from the handler's count.
232 for (auto& e: util::toRange(execs))
234 handler.passedCount--;
// Argument-less asynchronous run: clears the internal dummy opaque value and
// forwards to runAsync(util::any&) with it.
239 void ExecutableImpl::runAsync()
241 // Since runAsync() takes a modifiable `any`, reset it before the run
242 m_dummy_opaque = util::any();
243 runAsync(m_dummy_opaque);
// Asynchronous run with a caller-supplied opaque value, which is passed to
// mainExec->runAsync(). Requires mainExec to be non-null.
246 void ExecutableImpl::runAsync(util::any &opaque)
248 ASSERT(nullptr != mainExec);
// The additional executables are visited in reverse order before mainExec.
249 for (auto& e: util::toRangeReverse(execs))
253 mainExec->runAsync(opaque);
// Counterpart of runAsync(). Requires mainExec to be non-null.
256 void ExecutableImpl::wait()
258 ASSERT(nullptr != mainExec);
// Start with every additional executable counted in the handler.
259 ExecExceptionHandler handler(execs);
260 handler.passedCount = util::checked_cast<decltype(handler.passedCount)>(execs.size());
// Walk the additional executables in forward order, removing each one from
// the handler's count.
262 for (auto& e: util::toRange(execs))
264 handler.passedCount--;
// Creates a setup context for engine `e`.
// NOTE(review): presumably binds `e` to m_engine, which the forwarding
// methods of this class use — confirm.
269 ExecutionEngineSetupContext::ExecutionEngineSetupContext(ExecutionEngine& e):
// Forwards to ExecutionEngine::addPrePassCallback on the bound engine.
275 void ExecutionEngineSetupContext::addPrePassCallback(PassCallback callback)
277 m_engine.addPrePassCallback(std::move(callback));
// Forwards to ExecutionEngine::addPostPassCallback on the bound engine.
280 void ExecutionEngineSetupContext::addPostPassCallback(PassCallback callback)
282 m_engine.addPostPassCallback(std::move(callback));
// Forwards to ExecutionEngine::addExecutableDependency on the bound engine.
285 void ExecutionEngineSetupContext::addExecutableDependency(const std::string& lazyPassName)
287 m_engine.addExecutableDependency(lazyPassName);
// Forwards to ExecutionEngine::addPassStage(stageName) on the bound engine.
290 void ExecutionEngineSetupContext::addPassStage(const std::string& stageName)
292 m_engine.addPassStage(stageName);
// Forwards to ExecutionEngine::addPassStage(stageName, prevStage) on the bound engine.
295 void ExecutionEngineSetupContext::addPassStage(const std::string& stageName, const std::string& prevStage)
297 m_engine.addPassStage(stageName, prevStage);
// Returns the bound engine's pass stages (see ExecutionEngine::passStages).
300 ExecutionEngineSetupContext::StagesRange ExecutionEngineSetupContext::passStages() const
302 return m_engine.passStages();
// Invokes every stored callback, in container order, with the same pass
// description and pass context. Each stored callback must be non-null.
305 void ExecutionEngine::CallbackList::call(const ExecutionEngine::PassDesc& desc, const passes::PassContext& context) const
307 for (auto& callback: callbacks)
309 ASSERT(nullptr != callback);
310 callback(desc, context);
// Returns true when no callbacks are stored.
314 bool ExecutionEngine::CallbackList::empty() const
316 return callbacks.empty();
// Returns the graph listener associated with the lazy passes.
// NOTE(review): the returned expression is not shown here — presumably
// `last`; confirm.
319 IGraphListener* ExecutionEngine::LazyPasses::getListener() const
// Invariant: `last` is null exactly when there are no passes.
321 ASSERT((nullptr == last) == passes.empty());
// Looks up a lazy pass wrapper by name. The name must be non-empty and must
// be present in `passes`; the stored pointer must be non-null.
// NOTE(review): presumably returns `ret` — the return statement is not shown.
325 ExecutionEngine::LazyPassWrapper* ExecutionEngine::LazyPasses::getPass(const std::string& name) const
327 ASSERT(!name.empty());
328 ASSERT(util::contains(passes, name));
329 auto it = passes.find(name);
// Non-owning pointer obtained from the stored smart pointer.
330 auto ret = it->second.get();
331 ASSERT(nullptr != ret);
// Resets the lazy passes' state; ExecutionEngine::runPasses() calls this
// before running the pass pipeline.
335 void ExecutionEngine::LazyPasses::reset()
// Destructor of LazyPassWrapper.
343 ExecutionEngine::LazyPassWrapper::~LazyPassWrapper()
// Const query reporting whether this wrapper is currently valid.
// NOTE(review): the validity criterion is not shown here — confirm.
348 bool ExecutionEngine::LazyPassWrapper::isValid() const
// Returns true when this wrapper has no predecessor (m_prev is null).
353 bool ExecutionEngine::LazyPassWrapper::isFirst() const
355 return nullptr == m_prev;
// Constructs a wrapper given a pointer to the previous wrapper.
// NOTE(review): presumably stored in m_prev, which isFirst() compares
// against nullptr — confirm.
358 ExecutionEngine::LazyPassWrapper::LazyPassWrapper(ExecutionEngine::LazyPassWrapper* prev):
364 void ExecutionEngine::LazyPassWrapper::reset()