1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
13 #include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
14 #include <cpp_interfaces/ie_task_with_stages.hpp>
15 #include <cpp_interfaces/ie_task_executor.hpp>
16 #include <cpp_interfaces/exception2status.hpp>
17 #include "ie_infer_async_request_thread_safe_internal.hpp"
19 namespace InferenceEngine {
22 * @brief CallbackManager wraps invocation of the request completion callback
24 class CallbackManager {
25 std::exception_ptr _requestException = nullptr;
26 StatusCode _requestStatus = OK;
27 IInferRequest::CompletionCallback _callback = nullptr;
28 bool _enabled = false;
29 IInferRequest::WeakPtr _publicInterface;
30 ITaskExecutor::Ptr _callbackExecutor;
33 using Ptr = std::shared_ptr<CallbackManager>;
35 explicit CallbackManager(const ITaskExecutor::Ptr &callbackExecutor) : _callbackExecutor(callbackExecutor) {}  // keeps the executor that startTask() hands callback tasks to
37 void enableCallback() {
41 void disableCallback() {
45 bool isCallbackEnabled() { return _enabled && _callback != nullptr; }  // true only when callbacks are enabled AND a callback has been set
47 void startTask(Task::Ptr task) { _callbackExecutor->startTask(task); }  // NOTE(review): executor's startTask() result is discarded here, unlike startAsyncTask() which throws when it is false — confirm a refused callback task cannot be silently lost
50 _requestException = nullptr;
55 if (isCallbackEnabled()) {
56 auto requestPtr = _publicInterface.lock();
58 THROW_IE_EXCEPTION << "Failed to run callback: can't get pointer to request";
60 _callback(requestPtr, _requestStatus);
61 if (_requestException) std::rethrow_exception(_requestException);
65 void set_requestException(const std::exception_ptr &requestException) {
66 _requestException = requestException;
69 void set_requestStatus(StatusCode requestStatus) {
70 _requestStatus = requestStatus;
73 void set_callback(IInferRequest::CompletionCallback callback) {
78 void set_publicInterface(IInferRequest::Ptr publicInterface) {
79 _publicInterface = publicInterface;
83 class AsyncInferRequestThreadSafeDefault : public AsyncInferRequestThreadSafeInternal {
85 typedef std::shared_ptr<AsyncInferRequestThreadSafeDefault> Ptr;
87 explicit AsyncInferRequestThreadSafeDefault(InferRequestInternal::Ptr request,
88 const ITaskExecutor::Ptr &taskExecutor,
89 const TaskSynchronizer::Ptr &taskSynchronizer,
90 const ITaskExecutor::Ptr &callbackExecutor)
91 : _syncRequest(request),
92 _requestExecutor(taskExecutor),
93 _requestSynchronizer(taskSynchronizer),
94 _callbackManager(callbackExecutor) {
95 _syncTask = std::make_shared<Task>([this]() { _syncRequest->Infer(); });
96 _currentTask = _syncTask;
99 virtual ~AsyncInferRequestThreadSafeDefault() {
103 void waitAllAsyncTasks() {
105 while (!_listAsyncTasks.empty()) {
106 _listAsyncTasks.remove_if([this](StagedTask::Ptr task) -> bool {
107 auto sts = task->getStatus();
108 return !task->isOnWait() && (Task::Status::TS_DONE == sts || Task::Status::TS_ERROR == sts ||
109 Task::Status::TS_INITIAL == sts);
111 auto findIter = std::find_if(_listAsyncTasks.begin(), _listAsyncTasks.end(),
112 [this](StagedTask::Ptr task) { return !task->isOnWait(); });
113 if (findIter != _listAsyncTasks.end()) {
115 (*findIter)->wait(-1);
122 virtual void initNextAsyncTask() {
123 IE_PROFILING_AUTO_SCOPE(initNextAsyncTask)
124 // Most probably was called from callback (or when callback was started) or it was a sync task before, so new task is required
125 if (_currentTask->getStatus() == Task::Status::TS_POSTPONED || _currentTask == _syncTask) {
126 auto findIter = std::find_if(_listAsyncTasks.begin(), _listAsyncTasks.end(),
127 [this](StagedTask::Ptr task) -> bool {
128 return (!task->isOnWait()) && (task != _currentTask) &&
129 (Task::Status::TS_DONE == task->getStatus() ||
130 Task::Status::TS_ERROR == task->getStatus());
132 if (findIter == _listAsyncTasks.end()) {
133 _asyncTask = createAsyncRequestTask();
134 _listAsyncTasks.push_back(_asyncTask);
136 _asyncTask = *findIter;
139 _asyncTask->resetStages();
140 _currentTask = _asyncTask;
143 virtual void startAsyncTask() {
144 if (!_requestExecutor->startTask(_currentTask)) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
147 void StartAsync_ThreadUnsafe() override {
148 _syncRequest->checkBlobs();
149 _callbackManager.reset();
154 virtual void processAsyncTaskFailure(StagedTask::Ptr asyncTask) {
155 setIsRequestBusy(false);
156 auto requestException = std::current_exception();
157 // callback was set and hasn't been called, it must be called
158 if (_callbackManager.isCallbackEnabled() && asyncTask->getStage() >= 1) {
159 // jump to the "callback" stage because of happened error
160 while (asyncTask->getStage() != 1) asyncTask->stageDone();
161 _callbackManager.set_requestStatus(GENERAL_ERROR);
162 _callbackManager.set_requestException(requestException);
163 _callbackManager.startTask(asyncTask);
165 std::rethrow_exception(requestException);
169 virtual StagedTask::Ptr createAsyncRequestTask() {
170 return std::make_shared<StagedTask>([this]() {
171 auto asyncTaskCopy = _asyncTask;
173 switch (asyncTaskCopy->getStage()) {
175 _syncRequest->Infer();
176 asyncTaskCopy->stageDone();
177 if (_callbackManager.isCallbackEnabled()) {
178 _callbackManager.startTask(asyncTaskCopy);
180 asyncTaskCopy->stageDone();
185 setIsRequestBusy(false);
186 asyncTaskCopy->stageDone();
187 _callbackManager.runCallback();
194 processAsyncTaskFailure(asyncTaskCopy);
199 StatusCode Wait(int64_t millis_timeout) override {
200 auto taskCopy = _currentTask;
201 if (millis_timeout < IInferRequest::WaitMode::RESULT_READY) {
202 THROW_IE_EXCEPTION << PARAMETER_MISMATCH_str + "Timeout can't be less "
203 << IInferRequest::WaitMode::RESULT_READY
204 << " for InferRequest::Wait\n";
207 if (millis_timeout == IInferRequest::WaitMode::STATUS_ONLY) {
208 status = taskCopy->getStatus();
210 status = taskCopy->wait(millis_timeout);
211 setIsRequestBusy(false);
214 taskCopy->checkException();
215 return Task::TaskStatus2StatusCode(status);
218 void Infer_ThreadUnsafe() override {
219 _currentTask = _syncTask;
220 auto status = _currentTask->runWithSynchronizer(_requestSynchronizer);
221 if (status == Task::Status::TS_BUSY)
222 THROW_IE_EXCEPTION << "Internal error: AsyncInferRequestThreadSafeDefault failed to start sync task";
223 _currentTask->checkException();
226 void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const override {
227 _syncRequest->GetPerformanceCounts(perfMap);
230 void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data) override {
231 _syncRequest->SetBlob(name, data);
234 void GetBlob_ThreadUnsafe(const char *name, Blob::Ptr &data) override {
235 _syncRequest->GetBlob(name, data);
238 void SetCompletionCallback_ThreadUnsafe(InferenceEngine::IInferRequest::CompletionCallback callback) override {
239 _callbackManager.set_callback(callback);
242 void GetUserData_ThreadUnsafe(void **data) override {
243 if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
247 void SetUserData_ThreadUnsafe(void *data) override {
251 void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr) {
252 _callbackManager.set_publicInterface(ptr);
255 void SetBatch_ThreadUnsafe(int batch) override {
256 _syncRequest->SetBatch(batch);
260 ITaskExecutor::Ptr _requestExecutor;
261 TaskSynchronizer::Ptr _requestSynchronizer;
262 InferRequestInternal::Ptr _syncRequest;
264 StagedTask::Ptr _asyncTask;
265 Task::Ptr _currentTask;
266 std::list<StagedTask::Ptr> _listAsyncTasks;
268 CallbackManager _callbackManager;
271 } // namespace InferenceEngine