Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / cpp_interfaces / impl / ie_infer_async_request_thread_safe_default.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <algorithm>
#include <exception>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <string>

#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include <cpp_interfaces/ie_task_with_stages.hpp>
#include <cpp_interfaces/ie_task_executor.hpp>
#include <cpp_interfaces/exception2status.hpp>
#include "ie_infer_async_request_thread_safe_internal.hpp"
18
19 namespace InferenceEngine {
20
21 /**
22  * @class CallbackManager for wrapping calling of callback
23  */
24 class CallbackManager {
25     std::exception_ptr _requestException = nullptr;
26     StatusCode _requestStatus = OK;
27     IInferRequest::CompletionCallback _callback = nullptr;
28     bool _enabled = false;
29     IInferRequest::WeakPtr _publicInterface;
30     ITaskExecutor::Ptr _callbackExecutor;
31
32 public:
33     using Ptr = std::shared_ptr<CallbackManager>;
34
35     explicit CallbackManager(const ITaskExecutor::Ptr &callbackExecutor) : _callbackExecutor(callbackExecutor) {}
36
37     void enableCallback() {
38         _enabled = true;
39     }
40
41     void disableCallback() {
42         _enabled = false;
43     }
44
45     bool isCallbackEnabled() { return _enabled && _callback != nullptr; }
46
47     void startTask(Task::Ptr task) { _callbackExecutor->startTask(task); }
48
49     void reset() {
50         _requestException = nullptr;
51         _requestStatus = OK;
52     }
53
54     void runCallback() {
55         if (isCallbackEnabled()) {
56             auto requestPtr = _publicInterface.lock();
57             if (!requestPtr) {
58                 THROW_IE_EXCEPTION << "Failed to run callback: can't get pointer to request";
59             }
60             _callback(requestPtr, _requestStatus);
61             if (_requestException) std::rethrow_exception(_requestException);
62         }
63     }
64
65     void set_requestException(const std::exception_ptr &requestException) {
66         _requestException = requestException;
67     }
68
69     void set_requestStatus(StatusCode requestStatus) {
70         _requestStatus = requestStatus;
71     }
72
73     void set_callback(IInferRequest::CompletionCallback callback) {
74         enableCallback();
75         _callback = callback;
76     }
77
78     void set_publicInterface(IInferRequest::Ptr publicInterface) {
79         _publicInterface = publicInterface;
80     }
81 };
82
83 class AsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeInternal {
84 public:
85     typedef std::shared_ptr<AsyncInferRequestThreadSafeDefault> Ptr;
86
87     explicit AsyncInferRequestThreadSafeDefault(InferRequestInternal::Ptr request,
88                                                 const ITaskExecutor::Ptr &taskExecutor,
89                                                 const TaskSynchronizer::Ptr &taskSynchronizer,
90                                                 const ITaskExecutor::Ptr &callbackExecutor)
91             : _syncRequest(request),
92               _requestExecutor(taskExecutor),
93               _requestSynchronizer(taskSynchronizer),
94               _callbackManager(callbackExecutor) {
95         _syncTask = std::make_shared<Task>([this]() { _syncRequest->Infer(); });
96         _currentTask = _syncTask;
97     }
98
99     virtual ~AsyncInferRequestThreadSafeDefault() {
100         waitAllAsyncTasks();
101     }
102
103     void waitAllAsyncTasks() {
104         try {
105             while (!_listAsyncTasks.empty()) {
106                 _listAsyncTasks.remove_if([this](StagedTask::Ptr task) -> bool {
107                     auto sts = task->getStatus();
108                     return !task->isOnWait() && (Task::Status::TS_DONE == sts || Task::Status::TS_ERROR == sts ||
109                                                  Task::Status::TS_INITIAL == sts);
110                 });
111                 auto findIter = std::find_if(_listAsyncTasks.begin(), _listAsyncTasks.end(),
112                                              [this](StagedTask::Ptr task) { return !task->isOnWait(); });
113                 if (findIter != _listAsyncTasks.end()) {
114                     try {
115                         (*findIter)->wait(-1);
116                     } catch (...) {}
117                 }
118             }
119         } catch (...) {}
120     }
121
122     virtual void initNextAsyncTask() {
123         IE_PROFILING_AUTO_SCOPE(initNextAsyncTask)
124         // Most probably was called from callback (or when callback was started) or it was a sync task before, so new task is required
125         if (_currentTask->getStatus() == Task::Status::TS_POSTPONED || _currentTask == _syncTask) {
126             auto findIter = std::find_if(_listAsyncTasks.begin(), _listAsyncTasks.end(),
127                                          [this](StagedTask::Ptr task) -> bool {
128                                              return (!task->isOnWait()) && (task != _currentTask) &&
129                                                     (Task::Status::TS_DONE == task->getStatus() ||
130                                                      Task::Status::TS_ERROR == task->getStatus());
131                                          });
132             if (findIter == _listAsyncTasks.end()) {
133                 _asyncTask = createAsyncRequestTask();
134                 _listAsyncTasks.push_back(_asyncTask);
135             } else {
136                 _asyncTask = *findIter;
137             }
138         }
139         _asyncTask->resetStages();
140         _currentTask = _asyncTask;
141     }
142
143     virtual void startAsyncTask() {
144         if (!_requestExecutor->startTask(_currentTask)) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
145     }
146
147     void StartAsync_ThreadUnsafe() override {
148         _syncRequest->checkBlobs();
149         _callbackManager.reset();
150         initNextAsyncTask();
151         startAsyncTask();
152     }
153
154     virtual void processAsyncTaskFailure(StagedTask::Ptr asyncTask) {
155         setIsRequestBusy(false);
156         auto requestException = std::current_exception();
157         // callback was set and hasn't been called, it must be called
158         if (_callbackManager.isCallbackEnabled() && asyncTask->getStage() >= 1) {
159             // jump to the "callback" stage because of happened error
160             while (asyncTask->getStage() != 1) asyncTask->stageDone();
161             _callbackManager.set_requestStatus(GENERAL_ERROR);
162             _callbackManager.set_requestException(requestException);
163             _callbackManager.startTask(asyncTask);
164         } else {
165             std::rethrow_exception(requestException);
166         }
167     }
168
169     virtual StagedTask::Ptr createAsyncRequestTask() {
170         return std::make_shared<StagedTask>([this]() {
171             auto asyncTaskCopy = _asyncTask;
172             try {
173                 switch (asyncTaskCopy->getStage()) {
174                     case 2: {
175                         _syncRequest->Infer();
176                         asyncTaskCopy->stageDone();
177                         if (_callbackManager.isCallbackEnabled()) {
178                             _callbackManager.startTask(asyncTaskCopy);
179                         } else {
180                             asyncTaskCopy->stageDone();
181                         }
182                     }
183                         break;
184                     case 1: {
185                         setIsRequestBusy(false);
186                         asyncTaskCopy->stageDone();
187                         _callbackManager.runCallback();
188                     }
189                         break;
190                     default:
191                         break;
192                 }
193             } catch (...) {
194                 processAsyncTaskFailure(asyncTaskCopy);
195             }
196         }, 2);
197     }
198
199     StatusCode Wait(int64_t millis_timeout) override {
200         auto taskCopy = _currentTask;
201         if (millis_timeout < IInferRequest::WaitMode::RESULT_READY) {
202             THROW_IE_EXCEPTION << PARAMETER_MISMATCH_str + "Timeout can't be less "
203                                << IInferRequest::WaitMode::RESULT_READY
204                                << " for InferRequest::Wait\n";
205         }
206         Task::Status status;
207         if (millis_timeout == IInferRequest::WaitMode::STATUS_ONLY) {
208             status = taskCopy->getStatus();
209         } else {
210             status = taskCopy->wait(millis_timeout);
211             setIsRequestBusy(false);
212         }
213
214         taskCopy->checkException();
215         return Task::TaskStatus2StatusCode(status);
216     }
217
218     void Infer_ThreadUnsafe() override {
219         _currentTask = _syncTask;
220         auto status = _currentTask->runWithSynchronizer(_requestSynchronizer);
221         if (status == Task::Status::TS_BUSY)
222             THROW_IE_EXCEPTION << "Internal error: AsyncInferRequestThreadSafeDefault failed to start sync task";
223         _currentTask->checkException();
224     }
225
226     void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const override {
227         _syncRequest->GetPerformanceCounts(perfMap);
228     }
229
230     void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data) override {
231         _syncRequest->SetBlob(name, data);
232     }
233
234     void GetBlob_ThreadUnsafe(const char *name, Blob::Ptr &data) override {
235         _syncRequest->GetBlob(name, data);
236     }
237
238     void SetCompletionCallback_ThreadUnsafe(InferenceEngine::IInferRequest::CompletionCallback callback) override {
239         _callbackManager.set_callback(callback);
240     }
241
242     void GetUserData_ThreadUnsafe(void **data) override {
243         if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
244         *data = _userData;
245     }
246
247     void SetUserData_ThreadUnsafe(void *data) override {
248         _userData = data;
249     }
250
251     void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr) {
252         _callbackManager.set_publicInterface(ptr);
253     }
254
255     void SetBatch_ThreadUnsafe(int batch) override {
256         _syncRequest->SetBatch(batch);
257     }
258
259 protected:
260     ITaskExecutor::Ptr _requestExecutor;
261     TaskSynchronizer::Ptr _requestSynchronizer;
262     InferRequestInternal::Ptr _syncRequest;
263     Task::Ptr _syncTask;
264     StagedTask::Ptr _asyncTask;
265     Task::Ptr _currentTask;
266     std::list<StagedTask::Ptr> _listAsyncTasks;
267     void *_userData;
268     CallbackManager _callbackManager;
269 };
270
271 }  // namespace InferenceEngine