8aae70869c761ea3c86587eb72fa4ec9a5e7db16
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / cpp_interfaces / impl / ie_infer_async_request_thread_safe_internal.hpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #pragma once
7
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <cpp_interfaces/ie_task.hpp>
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
14
15 namespace InferenceEngine {
16
17 /**
18  * @brief Wrapper of async request to support thread-safe execution.
19  */
20 class AsyncInferRequestThreadSafeInternal : public IAsyncInferRequestInternal {
21     bool _isRequestBusy = false;
22     std::mutex _isBusyMutex;
23
24 public:
25     typedef std::shared_ptr<AsyncInferRequestThreadSafeInternal> Ptr;
26
27     AsyncInferRequestThreadSafeInternal() {
28         setIsRequestBusy(false);
29     }
30
31 protected:
32     virtual bool isRequestBusy() const {
33         return _isRequestBusy;
34     }
35
36     virtual void setIsRequestBusy(bool isBusy) {
37         std::unique_lock<std::mutex> lock(_isBusyMutex);
38         _isRequestBusy = isBusy;
39     }
40
41 public:
42     void StartAsync() override {
43         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
44         setIsRequestBusy(true);
45         try {
46             StartAsync_ThreadUnsafe();
47         } catch (...) {
48             setIsRequestBusy(false);
49             std::rethrow_exception(std::current_exception());
50         }
51     }
52
53     void GetUserData(void **data) override {
54         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
55         GetUserData_ThreadUnsafe(data);
56     }
57
58     void SetUserData(void *data) override {
59         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
60         SetUserData_ThreadUnsafe(data);
61     }
62
63     void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
64         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
65         SetCompletionCallback_ThreadUnsafe(callback);
66     }
67
68     void Infer() override {
69         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
70         setIsRequestBusy(true);
71         try {
72             Infer_ThreadUnsafe();
73         } catch (...) {
74             setIsRequestBusy(false);
75             std::rethrow_exception(std::current_exception());
76         }
77         setIsRequestBusy(false);
78     }
79
80     void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const override {
81         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
82         GetPerformanceCounts_ThreadUnsafe(perfMap);
83     }
84
85     void SetBlob(const char *name, const Blob::Ptr &data) override {
86         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
87         SetBlob_ThreadUnsafe(name, data);
88     }
89
90     void GetBlob(const char *name, Blob::Ptr &data) override {
91         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
92         GetBlob_ThreadUnsafe(name, data);
93     }
94
95     void SetBatch(int batch) override {
96         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
97         SetBatch_ThreadUnsafe(batch);
98     };
99
100     /**
101      * @brief methods with _ThreadUnsafe prefix are to implement in plugins
102      * or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
103      */
104     virtual void StartAsync_ThreadUnsafe() = 0;
105
106     virtual void GetUserData_ThreadUnsafe(void **data) = 0;
107
108     virtual void SetUserData_ThreadUnsafe(void *data) = 0;
109
110     virtual void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) = 0;
111
112     virtual void Infer_ThreadUnsafe() = 0;
113
114     virtual void
115     GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const = 0;
116
117     virtual void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data) = 0;
118
119     virtual void GetBlob_ThreadUnsafe(const char *name, Blob::Ptr &data) = 0;
120
121     virtual void SetBatch_ThreadUnsafe(int batch) = 0;
122 };
123
124 }  // namespace InferenceEngine