Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / hetero_plugin / hetero_infer_request.cpp
1 //
2 // Copyright 2017-2018 Intel Corporation.
3 //
4 // This software and the related documents are Intel copyrighted materials,
5 // and your use of them is governed by the express license under which they
6 // were provided to you (End User License Agreement for the Intel(R) Software
7 // Development Products (Version May 2017)). Unless the License provides
8 // otherwise, you may not use, modify, copy, publish, distribute, disclose or
9 // transmit this software or the related documents without Intel's prior
10 // written permission.
11 //
12 // This software and the related documents are provided as is, with no
13 // express or implied warranties, other than those that are expressly
14 // stated in the License.
15 //
16
17 #include "hetero_infer_request.h"
18 #include <ie_blob.h>
19 #include <ie_plugin.hpp>
20 #include <ie_util_internal.hpp>
21 #include <description_buffer.hpp>
22 #include <debug.h>
23 #include <ie_layouts.h>
24 #include <assert.h>
25 #include "ie_profiling.hpp"
26
27 using namespace HeteroPlugin;
28 using namespace InferenceEngine;
29
30 HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInputs,
31                                        InferenceEngine::OutputsDataMap networkOutputs,
32                                        const SubRequestsList &inferRequests) :
33         InferRequestInternal(networkInputs, networkOutputs),
34         _inferRequests(inferRequests) {
35     if (_networkOutputs.empty() || _networkInputs.empty()) {
36         THROW_IE_EXCEPTION << "Internal error: no information about network's output/input";
37     }
38
39     auto requestBlob([&](const std::string &e, InferenceEngine::InferRequest::Ptr r) {
40         if (networkInputs.find(e) != networkInputs.end()) {
41             if (_blobs.find(e) != _blobs.end()) {
42                 r->SetBlob(e.c_str(), _blobs[e]);
43             } else {
44                 _blobs[e] = r->GetBlob(e.c_str());
45                 _inputs[e] = _blobs[e];
46             }
47         } else if (networkOutputs.find(e) != networkOutputs.end()) {
48             if (_blobs.find(e) != _blobs.end()) {
49                 r->SetBlob(e.c_str(), _blobs[e]);
50             } else {
51                 _blobs[e] = r->GetBlob(e.c_str());
52                 _outputs[e] = _blobs[e];
53             }
54         } else {
55             if (_blobs.find(e) != _blobs.end()) {
56                 r->SetBlob(e.c_str(), _blobs[e]);
57             } else {
58                 _blobs[e] = r->GetBlob(e.c_str());
59             }
60         }
61     });
62
63     // go over all subnet and create requests
64     for (auto &&ireq : _inferRequests) {
65         ireq._request = ireq._network->CreateInferRequestPtr();
66         // go over all inputs and get blobs from subnet infer requests
67         for (auto e : ireq._oNames) {
68             requestBlob(e, ireq._request);
69         }
70     }
71
72     // go over all outputs and get blobs from subnet infer requests
73     for (auto r : _inferRequests) {
74         for (auto e : r._iNames) {
75             requestBlob(e, r._request);
76         }
77     }
78 }
79
80 void HeteroInferRequest::InferImpl() {
81     updateInOutIfNeeded();
82     size_t i = 0;
83     for (auto &&desc : _inferRequests) {
84         IE_PROFILING_AUTO_SCOPE_TASK(desc._profilingTask);
85         auto &r = desc._request;
86         assert(nullptr != r);
87         r->Infer();
88     }
89 }
90
91 void HeteroInferRequest::GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
92     perfMap.clear();
93     for (size_t i = 0; i < _inferRequests.size(); i++) {
94         auto perfMapRequest = _inferRequests[i]._request->GetPerformanceCounts();
95         for (auto &&r : perfMapRequest) {
96             perfMap[std::string("subgraph") + std::to_string(i) + ": " + r.first] = r.second;
97         }
98     }
99 }
100
101 void HeteroInferRequest::updateInOutIfNeeded() {
102     IE_PROFILING_AUTO_SCOPE(updateInOutIfNeeded);
103     assert(!_inferRequests.empty());
104     for (auto &&desc : _inferRequests) {
105         auto &r = desc._request;
106         assert(nullptr != r);
107         for (auto &&ioname : desc._iNames) {
108             auto iti = _inputs.find(ioname);
109             if (iti != _inputs.end()) {
110                 auto it = _preProcData.find(ioname);
111                 if (it != _preProcData.end()) {
112                     if (it->second.getRoiBlob() != _blobs[ioname]) {
113                         r->SetBlob(ioname.c_str(), it->second.getRoiBlob());
114                         _blobs[ioname] = iti->second;
115                     }
116                 } else {
117                     if (iti->second != _blobs[ioname]) {
118                         r->SetBlob(ioname.c_str(), iti->second);
119                         _blobs[ioname] = iti->second;
120                     }
121                 }
122             }
123         }
124         for (auto &&ioname : desc._oNames) {
125             auto ito = _outputs.find(ioname);
126             if (ito != _outputs.end()) {
127                 if (ito->second != _blobs[ioname]) {
128                     r->SetBlob(ioname.c_str(), ito->second);
129                     _blobs[ioname] = ito->second;
130                 }
131             }
132         }
133     }
134 }
135
136 void HeteroInferRequest::startFirstAsyncRequest() {
137     auto firstAsyncRequest = _inferRequests.begin()->_request;
138     firstAsyncRequest->StartAsync();
139 }
140
141 void HeteroInferRequest::setCallbackForLastRequest(std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>& callback) {
142     auto lastRequest = _inferRequests.back()._request;
143     if (lastRequest) lastRequest->SetCompletionCallback(callback);
144 }
145
146 void HeteroInferRequest::setCallbackSequence() {
147     for (auto desc = _inferRequests.begin(); desc != _inferRequests.end(); desc++) {
148         auto &currentAsyncRequest = desc->_request;
149         auto nextRequestDesc = std::next(desc);
150         if (nextRequestDesc != _inferRequests.end()) {
151             currentAsyncRequest->SetCompletionCallback<std::function<void(InferRequest, StatusCode)>>(
152                     [=](InferRequest request, StatusCode sts) {
153                         IE_PROFILING_AUTO_SCOPE(Callback)
154                         if (sts == OK) {
155                             nextRequestDesc->_request->StartAsync();
156                         }
157                     });
158         }
159     }
160 }
161
162 StatusCode HeteroInferRequest::waitAllRequests(int64_t millis_timeout) {
163     StatusCode status = INFER_NOT_STARTED;
164     bool shareMsMode = true;
165     std::chrono::high_resolution_clock::time_point startTime;
166     int64_t msLeft;
167     if (millis_timeout == IInferRequest::WaitMode::STATUS_ONLY ||
168         millis_timeout == IInferRequest::WaitMode::RESULT_READY) {
169         shareMsMode = false;
170     }
171     for (auto it = _inferRequests.begin(); it != _inferRequests.end(); ++it) {
172         startTime = std::chrono::high_resolution_clock::now();
173         status = it->_request->Wait(millis_timeout);
174         msLeft = std::chrono::duration_cast<std::chrono::milliseconds>(
175                 std::chrono::high_resolution_clock::now() - startTime).count();
176         if (OK != status) {
177             return status;
178         }
179         if (shareMsMode) {
180             if (millis_timeout - msLeft > 0) {
181                 millis_timeout -= msLeft;
182             } else if (it != _inferRequests.end()) {
183                 return RESULT_NOT_READY;
184             }
185         }
186     }
187     return status;
188 }