Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / cpp_interfaces / impl / ie_infer_async_request_thread_safe_internal.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #pragma once
6
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <cpp_interfaces/ie_task.hpp>
#include "cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp"
#include "cpp_interfaces/impl/ie_infer_request_internal.hpp"
13
14 namespace InferenceEngine {
15
16 /**
17  * @brief Wrapper of async request to support thread-safe execution.
18  */
19 class AsyncInferRequestThreadSafeInternal : public IAsyncInferRequestInternal {
20     bool _isRequestBusy = false;
21     std::mutex _isBusyMutex;
22
23 public:
24     typedef std::shared_ptr<AsyncInferRequestThreadSafeInternal> Ptr;
25
26     AsyncInferRequestThreadSafeInternal() {
27         setIsRequestBusy(false);
28     }
29
30 protected:
31     virtual bool isRequestBusy() const {
32         return _isRequestBusy;
33     }
34
35     virtual void setIsRequestBusy(bool isBusy) {
36         std::unique_lock<std::mutex> lock(_isBusyMutex);
37         _isRequestBusy = isBusy;
38     }
39
40 public:
41     void StartAsync() override {
42         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
43         setIsRequestBusy(true);
44         try {
45             StartAsync_ThreadUnsafe();
46         } catch (...) {
47             setIsRequestBusy(false);
48             std::rethrow_exception(std::current_exception());
49         }
50     }
51
52     void GetUserData(void **data) override {
53         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
54         GetUserData_ThreadUnsafe(data);
55     }
56
57     void SetUserData(void *data) override {
58         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
59         SetUserData_ThreadUnsafe(data);
60     }
61
62     void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
63         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
64         SetCompletionCallback_ThreadUnsafe(callback);
65     }
66
67     void Infer() override {
68         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
69         setIsRequestBusy(true);
70         try {
71             Infer_ThreadUnsafe();
72         } catch (...) {
73             setIsRequestBusy(false);
74             std::rethrow_exception(std::current_exception());
75         }
76         setIsRequestBusy(false);
77     }
78
79     void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const override {
80         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
81         GetPerformanceCounts_ThreadUnsafe(perfMap);
82     }
83
84     void SetBlob(const char *name, const Blob::Ptr &data) override {
85         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
86         SetBlob_ThreadUnsafe(name, data);
87     }
88
89     void GetBlob(const char *name, Blob::Ptr &data) override {
90         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
91         GetBlob_ThreadUnsafe(name, data);
92     }
93
94     void SetBatch(int batch) override {
95         if (isRequestBusy()) THROW_IE_EXCEPTION << REQUEST_BUSY_str;
96         SetBatch_ThreadUnsafe(batch);
97     };
98
99     /**
100      * @brief methods with _ThreadUnsafe prefix are to implement in plugins
101      * or in default wrapper (e.g. AsyncInferRequestThreadSafeDefault)
102      */
103     virtual void StartAsync_ThreadUnsafe() = 0;
104
105     virtual void GetUserData_ThreadUnsafe(void **data) = 0;
106
107     virtual void SetUserData_ThreadUnsafe(void *data) = 0;
108
109     virtual void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) = 0;
110
111     virtual void Infer_ThreadUnsafe() = 0;
112
113     virtual void
114     GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const = 0;
115
116     virtual void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data) = 0;
117
118     virtual void GetBlob_ThreadUnsafe(const char *name, Blob::Ptr &data) = 0;
119
120     virtual void SetBatch_ThreadUnsafe(int batch) = 0;
121 };
122
123 }  // namespace InferenceEngine