1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 * @brief A header file that provides wrapper classes for infer requests and callbacks.
7 * @file ie_infer_request.hpp
14 #include "ie_iinfer_request.hpp"
15 #include "details/ie_exception_conversion.hpp"
17 namespace InferenceEngine {
// Abstract, type-erased completion callback. InferRequest keeps an instance of a
// concrete subclass in its `callback` member and forwards completion notifications to it.
21 class ICompletionCallbackWrapper {
23 virtual ~ICompletionCallbackWrapper() = default;
// Invoked with the finished request and its completion status. Declared noexcept:
// an exception escaping an override ends the program via std::terminate.
25 virtual void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept = 0;
// Generic wrapper around a user callable. The template parameter T and the `lambda`
// member are declared on lines not shown here (presumably `template <class T>` and `T lambda;`).
29 class CompletionCallbackWrapper : public ICompletionCallbackWrapper {
// Initializes the `lambda` member from the given callable.
32 explicit CompletionCallbackWrapper(const T &lambda) : lambda(lambda) {}
// Both the request and the status code are ignored (parameter names are commented out).
// The body is not visible here; presumably it invokes `lambda` with no arguments (TODO confirm).
34 void call(InferenceEngine::IInferRequest::Ptr /*request*/,
35 InferenceEngine::StatusCode /*code*/) const noexcept override {
// Specialization for IInferRequest::CompletionCallback: forwards the request and the
// status code unchanged to the stored callback.
41 class CompletionCallbackWrapper<IInferRequest::CompletionCallback> : public ICompletionCallbackWrapper {
42 IInferRequest::CompletionCallback callBack;
44 explicit CompletionCallbackWrapper(const IInferRequest::CompletionCallback &callBack) : callBack(callBack) {}
// NOTE(review): callBack is called without a null check. If CompletionCallback is a plain
// function pointer (confirm), constructing this wrapper with nullptr makes call() undefined behavior.
46 void call(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) const noexcept override {
47 callBack(request, code);
51 } // namespace details
54 * @brief This class is a wrapper of IInferRequest to provide setters/getters
55 * of input/output which operates with BlobMaps.
56 * It can throw exceptions safely for the application, where it is properly handled.
// The underlying request that every forwarding method calls into.
59 IInferRequest::Ptr actual;
// Owns the wrapper whose raw pointer SetCompletionCallback registers as the request's
// user data. callWrapper dereferences that raw pointer, so this object has to stay alive
// for as long as a completion notification can still arrive.
60 std::shared_ptr<details::ICompletionCallbackWrapper> callback;
// Static trampoline that SetCompletionCallback passes to IInferRequest::SetCompletionCallback.
// It reads the wrapper pointer stored as the request's user data and forwards the
// notification to it. `dsc` is declared on a line not shown here.
// NOTE(review): the result of GetUserData is ignored and pWrapper is dereferenced without
// a null check. If no user data was set, or the lookup fails, this dereferences nullptr.
62 static void callWrapper(InferenceEngine::IInferRequest::Ptr request, InferenceEngine::StatusCode code) {
63 details::ICompletionCallbackWrapper *pWrapper = nullptr;
65 request->GetUserData(reinterpret_cast<void**>(&pWrapper), &dsc);
66 pWrapper->call(request, code);
// Creates a wrapper with default-constructed `actual` and `callback` members. If
// IInferRequest::Ptr is a smart pointer (confirm), no underlying request is attached, and
// the forwarding methods must not be called until one is.
70 InferRequest() = default;
73 * @brief Sets input/output data to infer
74 * @note: Memory allocation does not happen
75 * @param name Name of input or output blob.
76 * @param data Reference to input or output blob. The type of a blob must match the network input precision and size.
// For an output blob, `data` must match the network output precision and size
// (the same requirement SetOutput documents).
// Forwards to IInferRequest::SetBlob through CALL_STATUS_FNC. That macro presumably turns a
// non-OK status into an exception; it comes from details/ie_exception_conversion.hpp (confirm).
78 void SetBlob(const std::string &name, const Blob::Ptr &data) {
79 CALL_STATUS_FNC(SetBlob, name.c_str(), data);
83 * @brief Wraps original method
84 * IInferRequest::GetBlob
// Fetches the blob registered under `name` into `data`, which is declared on a line not
// shown here. The return statement is also not shown; presumably it returns `data`.
// Throws (THROW_IE_EXCEPTION) if the fetched blob is null or has a null buffer().
86 Blob::Ptr GetBlob(const std::string &name) {
88 CALL_STATUS_FNC(GetBlob, name.c_str(), data);
// NOTE(review): this message is built on every call, including successful ones.
// It could be built only on the failure path.
89 std::string error = "Internal error: blob with name `" + name + "` is not allocated!";
90 auto blobPtr = data.get();
91 if (blobPtr == nullptr) THROW_IE_EXCEPTION << error;
92 if (blobPtr->buffer() == nullptr) THROW_IE_EXCEPTION << error;
97 * @brief Wraps original method
98 * IInferRequest::Infer
// Forwards to IInferRequest::Infer. A non-OK status is handled by CALL_STATUS_FNC_NO_ARGS
// (presumably by throwing; confirm in details/ie_exception_conversion.hpp).
101 CALL_STATUS_FNC_NO_ARGS(Infer);
105 * @brief Wraps original method
106 * IInferRequest::GetPerformanceCounts
// Collects the profiling information reported by the wrapped request into `perfMap`.
// The return statement is on a line not shown here; presumably it returns perfMap by value.
108 std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const {
109 std::map<std::string, InferenceEngineProfileInfo> perfMap;
110 CALL_STATUS_FNC(GetPerformanceCounts, perfMap);
115 * @brief Sets input data to infer
116 * @note: Memory allocation doesn't happen
117 * @param inputs - a reference to a map of input blobs accessed by input names.
118 * The type of Blob must correspond to the network input precision and size.
// Calls IInferRequest::SetBlob once per entry, in map iteration order. If CALL_STATUS_FNC
// throws on a non-OK status (confirm), the entries set before the failing one stay set,
// so the request can be left partially updated.
120 void SetInput(const BlobMap &inputs) {
121 for (auto &&input : inputs) {
122 CALL_STATUS_FNC(SetBlob, input.first.c_str(), input.second);
127 * @brief Sets data that will contain result of the inference
128 * @note: Memory allocation doesn't happen
129 * @param results - a reference to a map of result blobs accessed by output names.
130 * The type of Blob must correspond to the network output precision and size.
// Uses the same underlying IInferRequest::SetBlob call as SetInput, once per entry in map
// iteration order. If CALL_STATUS_FNC throws on a non-OK status (confirm), the entries set
// before the failing one stay set.
132 void SetOutput(const BlobMap &results) {
133 for (auto &&result : results) {
134 CALL_STATUS_FNC(SetBlob, result.first.c_str(), result.second);
139 * @brief Sets new batch size when dynamic batching is enabled in executable network that created this request.
140 * @param batch new batch size to be used by all the following inference calls for this request.
// Forwards `batch` unchanged to IInferRequest::SetBatch. No range check is done here;
// validation, if any, happens in the wrapped request.
142 void SetBatch(const int batch) {
143 CALL_STATUS_FNC(SetBatch, batch);
147 * constructs InferRequest from initialised shared_pointer
// Wraps an existing request by storing `request` in `actual`. The `callback` member stays
// default-initialized, so the new wrapper owns no completion callback until
// SetCompletionCallback is called on it.
// NOTE(review): `actual(std::move(request))` would avoid one extra reference-count update
// if Ptr is a shared pointer.
150 explicit InferRequest(IInferRequest::Ptr request) : actual(request) {}
153 * @brief Start inference of specified input(s) in asynchronous mode
154 * @note: It returns immediately. Inference starts also immediately.
// Forwards to IInferRequest::StartAsync. A non-OK status is handled by
// CALL_STATUS_FNC_NO_ARGS. Use Wait() or SetCompletionCallback() to observe completion.
157 CALL_STATUS_FNC_NO_ARGS(StartAsync);
161 * @brief Wraps original method
162 * IInferRequest::Wait
// Unlike the other forwarding methods, Wait does not use CALL_STATUS_FNC: it returns the
// StatusCode from IInferRequest::Wait to the caller as is, instead of raising it.
// The second argument is nullptr (presumably the optional response description), so any
// error text is discarded. The meaning of special millis_timeout values is defined by
// IInferRequest::Wait, which is not visible here.
164 StatusCode Wait(int64_t millis_timeout) {
165 return actual->Wait(millis_timeout, nullptr);
169 * @brief Wraps original method
170 * IInferRequest::SetCompletionCallback
172 * @param callbackToSet Lambda callback object which will be called on processing finish.
// This method does three things, in order:
//   1. Replaces any previously stored wrapper with one holding a copy of callbackToSet.
//   2. Registers the wrapper's raw pointer as the request's user data.
//   3. Installs the static callWrapper trampoline, which reads that pointer back on completion.
// NOTE(review): callback.reset() destroys the previous wrapper before SetUserData runs. If
// SetUserData fails, the request may still hold the old, now dangling, pointer.
// NOTE(review): the result of actual->SetCompletionCallback (presumably a StatusCode) is not
// checked, unlike the other calls in this class.
// NOTE(review): only this object owns the wrapper. Destroying this InferRequest, or calling
// this method again, while a completion can still fire leaves callWrapper with a dangling pointer.
175 void SetCompletionCallback(const T & callbackToSet) {
176 callback.reset(new details::CompletionCallbackWrapper<T>(callbackToSet));
177 CALL_STATUS_FNC(SetUserData, callback.get());
178 actual->SetCompletionCallback(callWrapper);
182 * @brief IInferRequest pointer to be used directly in CreateInferRequest functions
// NOTE(review): this conversion is implicit and returns a non-const reference. The body is
// not shown here; presumably it returns `actual`, which lets callers replace the wrapped
// request in place (confirm).
184 operator IInferRequest::Ptr &() {
// Negation test. The body is not shown here; presumably it returns true when `actual` is
// not set (confirm).
188 bool operator!() const noexcept {
// Explicit validity test. The body is not shown here; presumably it returns true when
// `actual` is set (confirm).
192 explicit operator bool() const noexcept {
// Shared-ownership handle to an InferRequest wrapper.
196 using Ptr = std::shared_ptr<InferRequest>;
// Specialization for callbacks that take the high-level InferRequest wrapper. Each
// notification wraps the raw request in a temporary InferRequest before calling `lambda`.
// NOTE(review): call() is noexcept, so std::terminate is called if `lambda` throws or is
// empty (std::function raises std::bad_function_call when empty).
// NOTE(review): the temporary InferRequest has an empty `callback` member. Calling
// SetCompletionCallback on it from inside the callback would register a wrapper that is
// destroyed as soon as call() returns.
202 class CompletionCallbackWrapper<std::function<void(InferRequest, StatusCode)>>
203 : public ICompletionCallbackWrapper {
204 std::function<void(InferRequest, StatusCode)> lambda;
206 explicit CompletionCallbackWrapper(const std::function<void(InferRequest, InferenceEngine::StatusCode)> &lambda)
209 void call(InferenceEngine::IInferRequest::Ptr request,
210 InferenceEngine::StatusCode code) const noexcept override {
211 lambda(InferRequest(request), code);
215 } // namespace details
216 } // namespace InferenceEngine