Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / src / hetero_plugin / hetero_infer_request.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "hetero_infer_request.hpp"
6 #include <ie_blob.h>
7 #include <ie_plugin.hpp>
8 #include <ie_util_internal.hpp>
9 #include <description_buffer.hpp>
10 #include <debug.h>
11 #include <ie_layouts.h>
12 #include <cassert>
13 #include <map>
14 #include <string>
15
16 using namespace HeteroPlugin;
17 using namespace InferenceEngine;
18
19 HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInputs,
20                                        InferenceEngine::OutputsDataMap networkOutputs,
21                                        const SubRequestsList &inferRequests) :
22         InferRequestInternal(networkInputs, networkOutputs),
23         _inferRequests(inferRequests) {
24     if (_networkOutputs.empty() || _networkInputs.empty()) {
25         THROW_IE_EXCEPTION << "Internal error: no information about network's output/input";
26     }
27
28     auto requestBlob([&](const std::string &e, InferenceEngine::InferRequest::Ptr r) {
29         if (networkInputs.find(e) != networkInputs.end()) {
30             if (_blobs.find(e) != _blobs.end()) {
31                 r->SetBlob(e.c_str(), _blobs[e]);
32             } else {
33                 _blobs[e] = r->GetBlob(e.c_str());
34                 _inputs[e] = _blobs[e];
35             }
36         } else if (networkOutputs.find(e) != networkOutputs.end()) {
37             if (_blobs.find(e) != _blobs.end()) {
38                 r->SetBlob(e.c_str(), _blobs[e]);
39             } else {
40                 _blobs[e] = r->GetBlob(e.c_str());
41                 _outputs[e] = _blobs[e];
42             }
43         } else {
44             if (_blobs.find(e) != _blobs.end()) {
45                 r->SetBlob(e.c_str(), _blobs[e]);
46             } else {
47                 _blobs[e] = r->GetBlob(e.c_str());
48             }
49         }
50     });
51
52     // go over all subnet and create requests
53     for (auto &&ireq : _inferRequests) {
54         ireq._request = ireq._network->CreateInferRequestPtr();
55         // go over all inputs and get blobs from subnet infer requests
56         for (auto e : ireq._oNames) {
57             requestBlob(e, ireq._request);
58         }
59     }
60
61     // go over all outputs and get blobs from subnet infer requests
62     for (auto r : _inferRequests) {
63         for (auto e : r._iNames) {
64             requestBlob(e, r._request);
65         }
66     }
67 }
68
69 void HeteroInferRequest::InferImpl() {
70     updateInOutIfNeeded();
71     size_t i = 0;
72     for (auto &&desc : _inferRequests) {
73         IE_PROFILING_AUTO_SCOPE_TASK(desc._profilingTask);
74         auto &r = desc._request;
75         assert(nullptr != r);
76         r->Infer();
77     }
78 }
79
80 void HeteroInferRequest::GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
81     perfMap.clear();
82     for (size_t i = 0; i < _inferRequests.size(); i++) {
83         auto perfMapRequest = _inferRequests[i]._request->GetPerformanceCounts();
84         for (auto &&r : perfMapRequest) {
85             perfMap[std::string("subgraph") + std::to_string(i) + ": " + r.first] = r.second;
86         }
87     }
88 }
89
90 void HeteroInferRequest::updateInOutIfNeeded() {
91     IE_PROFILING_AUTO_SCOPE(updateInOutIfNeeded);
92     assert(!_inferRequests.empty());
93     for (auto &&desc : _inferRequests) {
94         auto &r = desc._request;
95         assert(nullptr != r);
96         for (auto &&ioname : desc._iNames) {
97             auto iti = _inputs.find(ioname);
98             if (iti != _inputs.end()) {
99                 auto it = _preProcData.find(ioname);
100                 if (it != _preProcData.end()) {
101                     if (it->second.getRoiBlob() != _blobs[ioname]) {
102                         r->SetBlob(ioname.c_str(), it->second.getRoiBlob());
103                         _blobs[ioname] = iti->second;
104                     }
105                 } else {
106                     if (iti->second != _blobs[ioname]) {
107                         r->SetBlob(ioname.c_str(), iti->second);
108                         _blobs[ioname] = iti->second;
109                     }
110                 }
111             }
112         }
113         for (auto &&ioname : desc._oNames) {
114             auto ito = _outputs.find(ioname);
115             if (ito != _outputs.end()) {
116                 if (ito->second != _blobs[ioname]) {
117                     r->SetBlob(ioname.c_str(), ito->second);
118                     _blobs[ioname] = ito->second;
119                 }
120             }
121         }
122     }
123 }
124
125 void HeteroInferRequest::startFirstAsyncRequest() {
126     auto firstAsyncRequest = _inferRequests.begin()->_request;
127     firstAsyncRequest->StartAsync();
128 }
129
130 void HeteroInferRequest::setCallbackForLastRequest(std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>& callback) {
131     auto lastRequest = _inferRequests.back()._request;
132     if (lastRequest) lastRequest->SetCompletionCallback(callback);
133 }
134
135 void HeteroInferRequest::setCallbackSequence() {
136     for (auto desc = _inferRequests.begin(); desc != _inferRequests.end(); desc++) {
137         auto &currentAsyncRequest = desc->_request;
138         auto nextRequestDesc = std::next(desc);
139         if (nextRequestDesc != _inferRequests.end()) {
140             currentAsyncRequest->SetCompletionCallback<std::function<void(InferRequest, StatusCode)>>(
141                     [=](InferRequest request, StatusCode sts) {
142                         IE_PROFILING_AUTO_SCOPE(Callback)
143                         if (sts == OK) {
144                             nextRequestDesc->_request->StartAsync();
145                         }
146                     });
147         }
148     }
149 }
150
151 StatusCode HeteroInferRequest::waitAllRequests(int64_t millis_timeout) {
152     StatusCode status = INFER_NOT_STARTED;
153     bool shareMsMode = true;
154     std::chrono::high_resolution_clock::time_point startTime;
155     int64_t msLeft;
156     if (millis_timeout == IInferRequest::WaitMode::STATUS_ONLY ||
157         millis_timeout == IInferRequest::WaitMode::RESULT_READY) {
158         shareMsMode = false;
159     }
160     for (auto it = _inferRequests.begin(); it != _inferRequests.end(); ++it) {
161         startTime = std::chrono::high_resolution_clock::now();
162         status = it->_request->Wait(millis_timeout);
163         msLeft = std::chrono::duration_cast<std::chrono::milliseconds>(
164                 std::chrono::high_resolution_clock::now() - startTime).count();
165         if (OK != status) {
166             return status;
167         }
168         if (shareMsMode) {
169             if (millis_timeout - msLeft > 0) {
170                 millis_timeout -= msLeft;
171             } else if (it != _inferRequests.end()) {
172                 return RESULT_NOT_READY;
173             }
174         }
175     }
176     return status;
177 }