Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / include / cpp / ie_infer_request.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief A header file that provides wrapper classes for infer requests and callbacks.
7  * @file ie_infer_request.hpp
8  */
9 #pragma once
10
11 #include <memory>
12 #include <string>
13 #include <map>
14 #include "ie_iinfer_request.hpp"
15 #include "details/ie_exception_conversion.hpp"
16
17 namespace InferenceEngine {
18
19 namespace details {
20
21 class ICompletionCallbackWrapper {
22 public:
23     virtual ~ICompletionCallbackWrapper() = default;
24
25     virtual void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept = 0;
26 };
27
28 template<class T>
29 class CompletionCallbackWrapper : public ICompletionCallbackWrapper {
30     T lambda;
31 public:
32     explicit CompletionCallbackWrapper(const T &lambda) : lambda(lambda) {}
33
34     void call(InferenceEngine::IInferRequest::Ptr /*request*/,
35               InferenceEngine::StatusCode /*code*/) const noexcept override {
36         lambda();
37     }
38 };
39
40 template<>
41 class CompletionCallbackWrapper<IInferRequest::CompletionCallback> : public ICompletionCallbackWrapper {
42     IInferRequest::CompletionCallback callBack;
43 public:
44     explicit CompletionCallbackWrapper(const IInferRequest::CompletionCallback &callBack) : callBack(callBack) {}
45
46     void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept override {
47         callBack(request, code);
48     }
49 };
50
51 }  // namespace details
52
53 /**
54  * @brief This class is a wrapper of IInferRequest to provide setters/getters
55  * of input/output which operates with BlobMaps.
56  * It can throw exceptions safely for the application, where it is properly handled.
57  */
58 class InferRequest {
59     IInferRequest::Ptr actual;
60     std::shared_ptr<details::ICompletionCallbackWrapper> callback;
61
62     static void callWrapper(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) {
63         details::ICompletionCallbackWrapper *pWrapper = nullptr;
64         ResponseDesc dsc;
65         request->GetUserData(reinterpret_cast<void**>(&pWrapper), &dsc);
66         pWrapper->call(request, code);
67     }
68
69 public:
70     InferRequest() = default;
71
72     /**
73      * @brief Sets input/output data to infer
74      * @note: Memory allocation does not happen
75      * @param name Name of input or output blob.
76      * @param data Reference to input or output blob. The type of a blob must match the network input precision and size.
77      */
78     void SetBlob(const std::string &name, const Blob::Ptr &data) {
79         CALL_STATUS_FNC(SetBlob, name.c_str(), data);
80     }
81
82     /**
83      * @brief Wraps original method
84      * IInferRequest::GetBlob
85      */
86     Blob::Ptr GetBlob(const std::string &name) {
87         Blob::Ptr data;
88         CALL_STATUS_FNC(GetBlob, name.c_str(), data);
89         std::string error = "Internal error: blob with name `" + name + "` is not allocated!";
90         auto blobPtr = data.get();
91         if (blobPtr == nullptr) THROW_IE_EXCEPTION << error;
92         if (blobPtr->buffer() == nullptr) THROW_IE_EXCEPTION << error;
93         return data;
94     }
95
96     /**
97      * @brief Wraps original method
98      * IInferRequest::Infer
99      */
100     void Infer() {
101         CALL_STATUS_FNC_NO_ARGS(Infer);
102     }
103
104     /**
105      * @brief Wraps original method
106      * IInferRequest::GetPerformanceCounts
107      */
108     std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const {
109         std::map<std::string, InferenceEngineProfileInfo> perfMap;
110         CALL_STATUS_FNC(GetPerformanceCounts, perfMap);
111         return perfMap;
112     }
113
114     /**
115      * @brief Sets input data to infer
116      * @note: Memory allocation doesn't happen
117      * @param inputs - a reference to a map of input blobs accessed by input names.
118      *        The type of Blob must correspond to the network input precision and size.
119      */
120     void SetInput(const BlobMap &inputs) {
121         for (auto &&input : inputs) {
122             CALL_STATUS_FNC(SetBlob, input.first.c_str(), input.second);
123         }
124     }
125
126     /**
127      * @brief Sets data that will contain result of the inference
128      * @note: Memory allocation doesn't happen
129      * @param results - a reference to a map of result blobs accessed by output names.
130      *        The type of Blob must correspond to the network output precision and size.
131      */
132     void SetOutput(const BlobMap &results) {
133         for (auto &&result : results) {
134             CALL_STATUS_FNC(SetBlob, result.first.c_str(), result.second);
135         }
136     }
137
138     /**
139     * @brief Sets new batch size when dynamic batching is enabled in executable network that created this request.
140     * @param batch new batch size to be used by all the following inference calls for this request.
141     */
142     void SetBatch(const int batch) {
143         CALL_STATUS_FNC(SetBatch, batch);
144     }
145
146     /**
147      * constructs InferRequest from initialised shared_pointer
148      * @param actual
149      */
150     explicit InferRequest(IInferRequest::Ptr request) : actual(request) {}
151
152     /**
153      * @brief Start inference of specified input(s) in asynchronous mode
154      * @note: It returns immediately. Inference starts also immediately.
155      */
156     void StartAsync() {
157         CALL_STATUS_FNC_NO_ARGS(StartAsync);
158     }
159
160     /**
161      * @brief Wraps original method
162      * IInferRequest::Wait
163      */
164     StatusCode Wait(int64_t millis_timeout) {
165         return actual->Wait(millis_timeout, nullptr);
166     }
167
168     /**
169      * @brief Wraps original method
170      * IInferRequest::SetCompletionCallback
171      *
172      * @param callbackToSet Lambda callback object which will be called on processing finish.
173      */
174     template <class T>
175     void SetCompletionCallback(const T & callbackToSet) {
176         callback.reset(new details::CompletionCallbackWrapper<T>(callbackToSet));
177         CALL_STATUS_FNC(SetUserData, callback.get());
178         actual->SetCompletionCallback(callWrapper);
179     }
180
181     /**
182      * @brief  IInferRequest pointer to be used directly in CreateInferRequest functions
183      */
184     operator IInferRequest::Ptr &() {
185         return actual;
186     }
187
188     bool operator!() const noexcept {
189         return !actual;
190     }
191
192     explicit operator bool() const noexcept {
193         return !!actual;
194     }
195
196     using Ptr = std::shared_ptr<InferRequest>;
197 };
198
199 namespace details {
200
201 template<>
202 class CompletionCallbackWrapper<std::function<void(InferRequest, StatusCode)>>
203         : public ICompletionCallbackWrapper {
204     std::function<void(InferRequest, StatusCode)> lambda;
205 public:
206     explicit CompletionCallbackWrapper(const std::function<void(InferRequest, InferenceEngine::StatusCode)> &lambda)
207             : lambda(lambda) {}
208
209     void call(InferenceEngine::IInferRequest::Ptr request,
210               InferenceEngine::StatusCode code) const noexcept override {
211         lambda(InferRequest(request), code);
212     }
213 };
214
215 }  // namespace details
216 }  // namespace InferenceEngine