Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / ade / ade / source / execution_engine.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "execution_engine/execution_engine.hpp"
7 #include "execution_engine/backend.hpp"
8 #include "execution_engine/executable.hpp"
9
10 #include "graph.hpp"
11
12 #include "util/assert.hpp"
13 #include "util/range.hpp"
14 #include "util/checked_cast.hpp"
15
16 namespace ade
17 {
18
19 class ExecutableImpl final : public Executable
20 {
21     util::any m_dummy_opaque;
22
23 public:
24     void addExec(std::unique_ptr<BackendExecutable>&& exec);
25
26     virtual void run() override;
27     virtual void run(util::any &opaque) override;
28
29     virtual void runAsync() override;
30     virtual void runAsync(util::any &opaque) override;
31     virtual void wait() override;
32
33 private:
34     std::unique_ptr<BackendExecutable> mainExec;
35     std::vector<std::unique_ptr<BackendExecutable>> execs;
36 };
37
38 void BackendExecutable::run(util::any &)
39 {
40     // Default implementation calls run() (backward compatibility)
41     run();
42 }
43
44 void BackendExecutable::runAsync(util::any &)
45 {
46     // Default implementation calls runAsync() (backward compatibility)
47     runAsync();
48 }
49
50 ExecutionEngine::ExecutionEngine()
51 {
52
53 }
54
55 ExecutionEngine::~ExecutionEngine()
56 {
57
58 }
59
60 void ExecutionEngine::addPrePassCallback(ExecutionEngine::PassCallback callback)
61 {
62     ASSERT(nullptr != callback);
63     m_prePassCallbacks.callbacks.emplace_back(std::move(callback));
64 }
65
66 void ExecutionEngine::addPostPassCallback(ExecutionEngine::PassCallback callback)
67 {
68     ASSERT(nullptr != callback);
69     m_postPassCallbacks.callbacks.emplace_back(std::move(callback));
70 }
71
72 void ExecutionEngine::addBackend(std::unique_ptr<ExecutionBackend>&& backend)
73 {
74     ASSERT(nullptr != backend);
75     ASSERT(m_backends.end() == std::find(m_backends.begin(), m_backends.end(), backend));
76     m_backends.emplace_back(std::move(backend));
77 }
78
79 void ExecutionEngine::setupBackends()
80 {
81     ExecutionEngineSetupContext context(*this);
82     for (auto& b: m_backends)
83     {
84         b->setupExecutionEngine(context);
85     }
86 }
87
88 namespace
89 {
90 struct GraphListenerSetter final
91 {
92     Graph& graph;
93     IGraphListener* listener = nullptr;
94     GraphListenerSetter(Graph& gr, IGraphListener* l):
95         graph(gr), listener(l)
96     {
97         ASSERT(nullptr == graph.getListener());
98         graph.setListener(listener);
99     }
100     ~GraphListenerSetter()
101     {
102         ASSERT(listener == graph.getListener());
103         graph.setListener(nullptr);
104     }
105
106     GraphListenerSetter(const GraphListenerSetter&) = delete;
107     GraphListenerSetter& operator=(const GraphListenerSetter&) = delete;
108 };
109 }
110
111 void ExecutionEngine::runPasses(Graph& graph)
112 {
113     m_lazyPasses.reset();
114     GraphListenerSetter setter(graph, m_lazyPasses.getListener());
115     passes::PassContext context{graph};
116     m_passManager.run(context);
117     for (auto& str: m_executableDependencies)
118     {
119         ASSERT(!str.empty());
120         auto pass = m_lazyPasses.getPass(str);
121         ASSERT(nullptr != pass);
122         pass->process(context);
123     }
124 }
125
126 std::unique_ptr<Executable> ExecutionEngine::createExecutable(const Graph& graph)
127 {
128     std::unique_ptr<ExecutableImpl> ret;
129     for (auto& b : m_backends)
130     {
131         std::unique_ptr<BackendExecutable> bexec(b->createExecutable(graph));
132         if (nullptr != bexec)
133         {
134             if (nullptr == ret)
135             {
136                 ret.reset(new ExecutableImpl);
137             }
138             ret->addExec(std::move(bexec));
139         }
140     }
141
142     return std::move(ret);
143 }
144
145 void ExecutionEngine::addExecutableDependency(const std::string& lazyPassName)
146 {
147     ASSERT(!lazyPassName.empty());
148     ASSERT(m_lazyPasses.getPass(lazyPassName) != nullptr);
149     m_executableDependencies.emplace(lazyPassName);
150 }
151
152 void ExecutionEngine::addPassStage(const std::string& stageName)
153 {
154     ASSERT(!stageName.empty());
155     m_passManager.addStage(stageName);
156 }
157
158 void ExecutionEngine::addPassStage(const std::string& stageName, const std::string& prevStage)
159 {
160     ASSERT(!stageName.empty());
161     ASSERT(!prevStage.empty());
162     m_passManager.addStage(stageName, prevStage);
163 }
164
165 ExecutionEngine::StagesRange ExecutionEngine::passStages() const
166 {
167     return util::map<PassMapper>(m_passManager.stages());
168 }
169
170 void ExecutionEngine::prePass(const PassDesc& desc,
171                               const passes::PassContext& context)
172 {
173     m_prePassCallbacks.call(desc, context);
174 }
175
176 void ExecutionEngine::postPass(const PassDesc& desc,
177                                const passes::PassContext& context)
178 {
179     m_postPassCallbacks.call(desc, context);
180 }
181
182 void ExecutableImpl::addExec(std::unique_ptr<BackendExecutable>&& exec)
183 {
184     ASSERT(nullptr != exec);
185     if (nullptr == mainExec)
186     {
187         mainExec = std::move(exec);
188     }
189     else
190     {
191         execs.emplace_back(std::move(exec));
192     }
193 }
194
195 struct ExecExceptionHandler
196 {
197     size_t passedCount = 0;
198     std::vector<std::unique_ptr<BackendExecutable>> &handledVector;
199     ExecExceptionHandler(std::vector<std::unique_ptr<BackendExecutable>> &execs) : handledVector(execs) {}
200     ~ExecExceptionHandler()
201     {
202         ASSERT(handledVector.size() >= passedCount);
203         auto count = util::checked_cast<int>(handledVector.size() - passedCount);
204         for (auto i = util::checked_cast<int>(handledVector.size()) - 1;
205              i >= count;
206              i--)
207         {
208             handledVector[i]->cancel();
209         }
210     }
211 };
212
213 void ExecutableImpl::run()
214 {
215     // Since run() takes a modifiable `any`, reset it before the run
216     m_dummy_opaque = util::any();
217     run(m_dummy_opaque);
218 }
219
220 void ExecutableImpl::run(util::any &opaque)
221 {
222     ASSERT(nullptr != mainExec);
223     ExecExceptionHandler handler(execs);
224     for (auto& e: util::toRangeReverse(execs))
225     {
226         e->runAsync(opaque);
227         handler.passedCount++;
228     }
229
230     mainExec->run(opaque);
231
232     for (auto& e: util::toRange(execs))
233     {
234         handler.passedCount--;
235         e->wait();
236     }
237 }
238
239 void ExecutableImpl::runAsync()
240 {
241     // Since runAsync() takes a modifiable `any`, reset it before the run
242     m_dummy_opaque = util::any();
243     runAsync(m_dummy_opaque);
244 }
245
246 void ExecutableImpl::runAsync(util::any &opaque)
247 {
248     ASSERT(nullptr != mainExec);
249     for (auto& e: util::toRangeReverse(execs))
250     {
251         e->runAsync(opaque);
252     }
253     mainExec->runAsync(opaque);
254 }
255
256 void ExecutableImpl::wait()
257 {
258     ASSERT(nullptr != mainExec);
259     ExecExceptionHandler handler(execs);
260     handler.passedCount = util::checked_cast<decltype(handler.passedCount)>(execs.size());
261     mainExec->wait();
262     for (auto& e: util::toRange(execs))
263     {
264         handler.passedCount--;
265         e->wait();
266     }
267 }
268
269 ExecutionEngineSetupContext::ExecutionEngineSetupContext(ExecutionEngine& e):
270     m_engine(e)
271 {
272
273 }
274
275 void ExecutionEngineSetupContext::addPrePassCallback(PassCallback callback)
276 {
277     m_engine.addPrePassCallback(std::move(callback));
278 }
279
280 void ExecutionEngineSetupContext::addPostPassCallback(PassCallback callback)
281 {
282     m_engine.addPostPassCallback(std::move(callback));
283 }
284
285 void ExecutionEngineSetupContext::addExecutableDependency(const std::string& lazyPassName)
286 {
287     m_engine.addExecutableDependency(lazyPassName);
288 }
289
290 void ExecutionEngineSetupContext::addPassStage(const std::string& stageName)
291 {
292     m_engine.addPassStage(stageName);
293 }
294
295 void ExecutionEngineSetupContext::addPassStage(const std::string& stageName, const std::string& prevStage)
296 {
297     m_engine.addPassStage(stageName, prevStage);
298 }
299
300 ExecutionEngineSetupContext::StagesRange ExecutionEngineSetupContext::passStages() const
301 {
302     return m_engine.passStages();
303 }
304
305 void ExecutionEngine::CallbackList::call(const ExecutionEngine::PassDesc& desc, const passes::PassContext& context) const
306 {
307     for (auto& callback: callbacks)
308     {
309         ASSERT(nullptr != callback);
310         callback(desc, context);
311     }
312 }
313
314 bool ExecutionEngine::CallbackList::empty() const
315 {
316     return callbacks.empty();
317 }
318
319 IGraphListener* ExecutionEngine::LazyPasses::getListener() const
320 {
321     ASSERT((nullptr == last) == passes.empty());
322     return last;
323 }
324
325 ExecutionEngine::LazyPassWrapper* ExecutionEngine::LazyPasses::getPass(const std::string& name) const
326 {
327     ASSERT(!name.empty());
328     ASSERT(util::contains(passes, name));
329     auto it = passes.find(name);
330     auto ret = it->second.get();
331     ASSERT(nullptr != ret);
332     return ret;
333 }
334
335 void ExecutionEngine::LazyPasses::reset()
336 {
337     if (nullptr != last)
338     {
339         last->reset();
340     }
341 }
342
343 ExecutionEngine::LazyPassWrapper::~LazyPassWrapper()
344 {
345
346 }
347
348 bool ExecutionEngine::LazyPassWrapper::isValid() const
349 {
350     return m_valid;
351 }
352
353 bool ExecutionEngine::LazyPassWrapper::isFirst() const
354 {
355     return nullptr == m_prev;
356 }
357
358 ExecutionEngine::LazyPassWrapper::LazyPassWrapper(ExecutionEngine::LazyPassWrapper* prev):
359     m_prev(prev)
360 {
361
362 }
363
364 void ExecutionEngine::LazyPassWrapper::reset()
365 {
366     m_valid = false;
367     if (!isFirst())
368     {
369         m_prev->reset();
370     }
371 }
372
373 }