2 // Copyright 2017-2018 Intel Corporation.
4 // This software and the related documents are Intel copyrighted materials,
5 // and your use of them is governed by the express license under which they
6 // were provided to you (End User License Agreement for the Intel(R) Software
7 // Development Products (Version May 2017)). Unless the License provides
8 // otherwise, you may not use, modify, copy, publish, distribute, disclose or
9 // transmit this software or the related documents without Intel's prior
10 // written permission.
12 // This software and the related documents are provided as is, with no
13 // express or implied warranties, other than those that are expressly
14 // stated in the License.
17 #include "hetero_infer_request.h"
19 #include <ie_plugin.hpp>
20 #include <ie_util_internal.hpp>
21 #include <description_buffer.hpp>
23 #include <ie_layouts.h>
25 #include "ie_profiling.hpp"
27 using namespace HeteroPlugin;
28 using namespace InferenceEngine;
30 HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInputs,
31 InferenceEngine::OutputsDataMap networkOutputs,
32 const SubRequestsList &inferRequests) :
33 InferRequestInternal(networkInputs, networkOutputs),
34 _inferRequests(inferRequests) {
35 if (_networkOutputs.empty() || _networkInputs.empty()) {
36 THROW_IE_EXCEPTION << "Internal error: no information about network's output/input";
39 auto requestBlob([&](const std::string &e, InferenceEngine::InferRequest::Ptr r) {
40 if (networkInputs.find(e) != networkInputs.end()) {
41 if (_blobs.find(e) != _blobs.end()) {
42 r->SetBlob(e.c_str(), _blobs[e]);
44 _blobs[e] = r->GetBlob(e.c_str());
45 _inputs[e] = _blobs[e];
47 } else if (networkOutputs.find(e) != networkOutputs.end()) {
48 if (_blobs.find(e) != _blobs.end()) {
49 r->SetBlob(e.c_str(), _blobs[e]);
51 _blobs[e] = r->GetBlob(e.c_str());
52 _outputs[e] = _blobs[e];
55 if (_blobs.find(e) != _blobs.end()) {
56 r->SetBlob(e.c_str(), _blobs[e]);
58 _blobs[e] = r->GetBlob(e.c_str());
63 // go over all subnet and create requests
64 for (auto &&ireq : _inferRequests) {
65 ireq._request = ireq._network->CreateInferRequestPtr();
66 // go over all inputs and get blobs from subnet infer requests
67 for (auto e : ireq._oNames) {
68 requestBlob(e, ireq._request);
72 // go over all outputs and get blobs from subnet infer requests
73 for (auto r : _inferRequests) {
74 for (auto e : r._iNames) {
75 requestBlob(e, r._request);
// Synchronous inference entry point. First calls updateInOutIfNeeded() so that any blob in
// _inputs/_outputs that differs from the one cached in _blobs is set on the matching sub-request,
// then walks the sub-requests in _inferRequests order, each inside its own profiling task
// (desc._profilingTask).
// NOTE(review): the embedded numbering jumps 81 -> 83 and 85 -> 91, so original lines 82 and 86-90
// (presumably the call that runs `r` plus the closing braces) are missing from this listing and
// must be restored from upstream before this compiles.
80 void HeteroInferRequest::InferImpl() {
81 updateInOutIfNeeded();
83 for (auto &&desc : _inferRequests) {
84 IE_PROFILING_AUTO_SCOPE_TASK(desc._profilingTask);
85 auto &r = desc._request;
// Merges the per-layer profiling counters of every sub-request into perfMap. Each key is the
// sub-request's counter name prefixed with "subgraph<i>: ", where i is the sub-request's index in
// _inferRequests, so same-named layers from different sub-networks stay distinct; an entry already
// in perfMap under the same key is overwritten.
// NOTE(review): original lines 92 and 97-100 are missing from this listing (gaps in the embedded
// numbering); line 92 may prepare perfMap (e.g. clear it) — confirm against upstream.
91 void HeteroInferRequest::GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
93 for (size_t i = 0; i < _inferRequests.size(); i++) {
94 auto perfMapRequest = _inferRequests[i]._request->GetPerformanceCounts();
95 for (auto &&r : perfMapRequest) {
96 perfMap[std::string("subgraph") + std::to_string(i) + ": " + r.first] = r.second;
101 void HeteroInferRequest::updateInOutIfNeeded() {
102 IE_PROFILING_AUTO_SCOPE(updateInOutIfNeeded);
103 assert(!_inferRequests.empty());
104 for (auto &&desc : _inferRequests) {
105 auto &r = desc._request;
106 assert(nullptr != r);
107 for (auto &&ioname : desc._iNames) {
108 auto iti = _inputs.find(ioname);
109 if (iti != _inputs.end()) {
110 auto it = _preProcData.find(ioname);
111 if (it != _preProcData.end()) {
112 if (it->second.getRoiBlob() != _blobs[ioname]) {
113 r->SetBlob(ioname.c_str(), it->second.getRoiBlob());
114 _blobs[ioname] = iti->second;
117 if (iti->second != _blobs[ioname]) {
118 r->SetBlob(ioname.c_str(), iti->second);
119 _blobs[ioname] = iti->second;
124 for (auto &&ioname : desc._oNames) {
125 auto ito = _outputs.find(ioname);
126 if (ito != _outputs.end()) {
127 if (ito->second != _blobs[ioname]) {
128 r->SetBlob(ioname.c_str(), ito->second);
129 _blobs[ioname] = ito->second;
136 void HeteroInferRequest::startFirstAsyncRequest() {
137 auto firstAsyncRequest = _inferRequests.begin()->_request;
138 firstAsyncRequest->StartAsync();
141 void HeteroInferRequest::setCallbackForLastRequest(std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>& callback) {
142 auto lastRequest = _inferRequests.back()._request;
143 if (lastRequest) lastRequest->SetCompletionCallback(callback);
// Chains the sub-requests for asynchronous execution: for every sub-request except the last one in
// _inferRequests, installs a completion callback that calls StartAsync() on the next sub-request.
// The last sub-request's callback is not touched here (see setCallbackForLastRequest).
// NOTE(review): the lambda captures `nextRequestDesc` (an iterator into _inferRequests) by value,
// so the callbacks rely on _inferRequests not being modified afterwards (if SubRequestsList is a
// std::vector, any reallocation would leave the iterator dangling) — type not visible here.
146 void HeteroInferRequest::setCallbackSequence() {
147 for (auto desc = _inferRequests.begin(); desc != _inferRequests.end(); desc++) {
// NOTE(review): `¤tAsyncRequest` on the next line is an encoding artifact: the "&curren" in
// "&currentAsyncRequest" was decoded as the HTML entity for the character shown. As written it
// does not compile; it should read `auto &currentAsyncRequest = desc->_request;` (the name used
// by the SetCompletionCallback call below).
148 auto ¤tAsyncRequest = desc->_request;
149 auto nextRequestDesc = std::next(desc);
150 if (nextRequestDesc != _inferRequests.end()) {
151 currentAsyncRequest->SetCompletionCallback<std::function<void(InferRequest, StatusCode)>>(
152 [=](InferRequest request, StatusCode sts) {
153 IE_PROFILING_AUTO_SCOPE(Callback)
// NOTE(review): original line 154 and lines 156-161 are missing from this listing; line 154 may
// make the StartAsync() call below conditional on `sts` — confirm against upstream.
155 nextRequestDesc->_request->StartAsync();
162 StatusCode HeteroInferRequest::waitAllRequests(int64_t millis_timeout) {
163 StatusCode status = INFER_NOT_STARTED;
164 bool shareMsMode = true;
165 std::chrono::high_resolution_clock::time_point startTime;
167 if (millis_timeout == IInferRequest::WaitMode::STATUS_ONLY ||
168 millis_timeout == IInferRequest::WaitMode::RESULT_READY) {
171 for (auto it = _inferRequests.begin(); it != _inferRequests.end(); ++it) {
172 startTime = std::chrono::high_resolution_clock::now();
173 status = it->_request->Wait(millis_timeout);
174 msLeft = std::chrono::duration_cast<std::chrono::milliseconds>(
175 std::chrono::high_resolution_clock::now() - startTime).count();
180 if (millis_timeout - msLeft > 0) {
181 millis_timeout -= msLeft;
182 } else if (it != _inferRequests.end()) {
183 return RESULT_NOT_READY;