1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "hetero_infer_request.hpp"
7 #include <ie_plugin.hpp>
8 #include <ie_util_internal.hpp>
9 #include <description_buffer.hpp>
11 #include <ie_layouts.h>
16 using namespace HeteroPlugin;
17 using namespace InferenceEngine;
// Constructs the heterogeneous infer request.
//
// networkInputs / networkOutputs are forwarded to the InferRequestInternal base;
// inferRequests (the list of per-subnetwork request descriptors) is copied into
// _inferRequests. For every descriptor a sub-request is created and the blobs
// named in its _oNames / _iNames lists are registered in _blobs.
//
// Throws (THROW_IE_EXCEPTION) when _networkInputs or _networkOutputs is empty.
//
// NOTE(review): the embedded line numbers in this listing skip, so some source
// lines (closing braces, branch keywords) are not visible here; comments below
// that depend on them are marked as assumptions.
19 HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInputs,
20 InferenceEngine::OutputsDataMap networkOutputs,
21 const SubRequestsList &inferRequests) :
22 InferRequestInternal(networkInputs, networkOutputs),
23 _inferRequests(inferRequests) {
24 if (_networkOutputs.empty() || _networkInputs.empty()) {
25 THROW_IE_EXCEPTION << "Internal error: no information about network's output/input";
// Helper: binds blob name `e` to sub-request `r`.
// A name already present in _blobs is pushed into the sub-request with SetBlob.
// The GetBlob statements below take the blob from the sub-request and store it in
// _blobs (and in _inputs / _outputs for network inputs / outputs); presumably they
// sit in an else-branch that is not visible in this listing - TODO confirm.
// The lambda captures by reference; it is only invoked within this constructor in
// the visible code.
28 auto requestBlob([&](const std::string &e, InferenceEngine::InferRequest::Ptr r) {
// Case 1: `e` is one of the network's inputs.
29 if (networkInputs.find(e) != networkInputs.end()) {
30 if (_blobs.find(e) != _blobs.end()) {
31 r->SetBlob(e.c_str(), _blobs[e]);
33 _blobs[e] = r->GetBlob(e.c_str());
34 _inputs[e] = _blobs[e];
// Case 2: `e` is one of the network's outputs.
36 } else if (networkOutputs.find(e) != networkOutputs.end()) {
37 if (_blobs.find(e) != _blobs.end()) {
38 r->SetBlob(e.c_str(), _blobs[e]);
40 _blobs[e] = r->GetBlob(e.c_str());
41 _outputs[e] = _blobs[e];
// Remaining case: only _blobs is touched, neither _inputs nor _outputs.
// Presumably this is the final else for names that are neither network inputs
// nor network outputs - TODO confirm, the branch keyword is not visible.
44 if (_blobs.find(e) != _blobs.end()) {
45 r->SetBlob(e.c_str(), _blobs[e]);
47 _blobs[e] = r->GetBlob(e.c_str());
52 // go over all subnet and create requests
53 for (auto &&ireq : _inferRequests) {
54 ireq._request = ireq._network->CreateInferRequestPtr();
55 // go over the names in this subnet's _oNames and bind their blobs via requestBlob
56 for (auto e : ireq._oNames) {
57 requestBlob(e, ireq._request);
61 // second pass: go over the names in every subnet's _iNames and bind their blobs via requestBlob
// Runs after the first loop, so every _oNames blob is already registered in _blobs.
62 for (auto r : _inferRequests) {
63 for (auto e : r._iNames) {
64 requestBlob(e, r._request);
// Synchronous inference entry point.
// First re-binds any user-replaced input/output blobs (updateInOutIfNeeded), then
// walks the sub-request descriptors in _inferRequests order, opening a profiling
// scope for each descriptor's _profilingTask.
// NOTE(review): the statements that use `r` are not visible in this listing
// (embedded line numbers skip from 74 to 80); presumably each sub-request is
// executed here - TODO confirm against the full source.
69 void HeteroInferRequest::InferImpl() {
70 updateInOutIfNeeded();
72 for (auto &&desc : _inferRequests) {
73 IE_PROFILING_AUTO_SCOPE_TASK(desc._profilingTask);
74 auto &r = desc._request;
// Collects performance counters from every sub-request into perfMap.
//
// perfMap: output map. Each counter reported by sub-request number i is stored
// under the key "subgraph<i>: <counter name>", where i is the index of the
// sub-request in _inferRequests. Entries are written with operator[], so an
// existing entry with the same key is overwritten.
80 void HeteroInferRequest::GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) const {
82 for (size_t i = 0; i < _inferRequests.size(); i++) {
83 auto perfMapRequest = _inferRequests[i]._request->GetPerformanceCounts();
84 for (auto &&r : perfMapRequest) {
// Prefix the counter name with the sub-request index to keep keys distinct across subgraphs.
85 perfMap[std::string("subgraph") + std::to_string(i) + ": " + r.first] = r.second;
// Propagates blobs from _inputs / _outputs / _preProcData to the sub-requests.
//
// For every sub-request descriptor, each name in _iNames that is present in
// _inputs and each name in _oNames that is present in _outputs is compared with
// the blob recorded in _blobs; on a mismatch the blob is set on the sub-request
// with SetBlob and _blobs is updated.
//
// NOTE(review): some lines (closing braces, branch keywords) are not visible in
// this listing; see the assumption notes below.
90 void HeteroInferRequest::updateInOutIfNeeded() {
91 IE_PROFILING_AUTO_SCOPE(updateInOutIfNeeded);
92 assert(!_inferRequests.empty());
93 for (auto &&desc : _inferRequests) {
94 auto &r = desc._request;
// Inputs: only names that appear in _inputs are considered.
96 for (auto &&ioname : desc._iNames) {
97 auto iti = _inputs.find(ioname);
98 if (iti != _inputs.end()) {
// When pre-processing data exists for this name, its ROI blob is what gets
// set on the sub-request.
99 auto it = _preProcData.find(ioname);
100 if (it != _preProcData.end()) {
// NOTE(review): the comparison uses the ROI blob, but _blobs records
// iti->second; if those two differ, this condition stays true on the next
// call and SetBlob is repeated - confirm this is intended.
101 if (it->second.getRoiBlob() != _blobs[ioname]) {
102 r->SetBlob(ioname.c_str(), it->second.getRoiBlob());
103 _blobs[ioname] = iti->second;
// Presumably the else-branch (no pre-processing data for this name) - TODO
// confirm, the branch keyword is not visible: the blob from _inputs is set directly.
106 if (iti->second != _blobs[ioname]) {
107 r->SetBlob(ioname.c_str(), iti->second);
108 _blobs[ioname] = iti->second;
// Outputs: same mismatch check, without any pre-processing lookup.
113 for (auto &&ioname : desc._oNames) {
114 auto ito = _outputs.find(ioname);
115 if (ito != _outputs.end()) {
116 if (ito->second != _blobs[ioname]) {
117 r->SetBlob(ioname.c_str(), ito->second);
118 _blobs[ioname] = ito->second;
// Starts asynchronous execution of the first sub-request in _inferRequests.
// NOTE(review): no emptiness or null check is visible here; this presumably
// relies on _inferRequests being non-empty with requests already created - confirm.
125 void HeteroInferRequest::startFirstAsyncRequest() {
126 auto firstAsyncRequest = _inferRequests.begin()->_request;
127 firstAsyncRequest->StartAsync();
// Installs `callback` as the completion callback of the last sub-request in
// _inferRequests.
//
// callback: function taking (InferRequest, StatusCode); passed on unchanged.
// Nothing is installed when the last descriptor's _request is null.
// NOTE(review): back() is called without a visible emptiness check - presumably
// _inferRequests is never empty here; confirm.
130 void HeteroInferRequest::setCallbackForLastRequest(std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>& callback) {
131 auto lastRequest = _inferRequests.back()._request;
132 if (lastRequest) lastRequest->SetCompletionCallback(callback);
// Chains the sub-requests: every sub-request except the last gets a completion
// callback that starts the following sub-request asynchronously.
//
// NOTE(review): some lines of the callback body are not visible in this listing
// (embedded line numbers skip 143 and 145-150), so any condition guarding the
// StartAsync call cannot be seen here - check the full source.
135 void HeteroInferRequest::setCallbackSequence() {
136 for (auto desc = _inferRequests.begin(); desc != _inferRequests.end(); desc++) {
// NOTE(review): "¤t" below looks like a mangled "&current" (the name
// currentAsyncRequest is used a few lines further down) - restore from the
// original source.
137 auto ¤tAsyncRequest = desc->_request;
138 auto nextRequestDesc = std::next(desc);
// The last descriptor has no successor, so no chaining callback is set for it here.
139 if (nextRequestDesc != _inferRequests.end()) {
140 currentAsyncRequest->SetCompletionCallback<std::function<void(InferRequest, StatusCode)>>(
// [=] copies the nextRequestDesc iterator into the callback; it refers into
// _inferRequests, so the callback presumably must not outlive this object
// or a change to that container - confirm.
141 [=](InferRequest request, StatusCode sts) {
142 IE_PROFILING_AUTO_SCOPE(Callback)
144 nextRequestDesc->_request->StartAsync();
151 StatusCode HeteroInferRequest::waitAllRequests(int64_t millis_timeout) {
152 StatusCode status = INFER_NOT_STARTED;
153 bool shareMsMode = true;
154 std::chrono::high_resolution_clock::time_point startTime;
156 if (millis_timeout == IInferRequest::WaitMode::STATUS_ONLY ||
157 millis_timeout == IInferRequest::WaitMode::RESULT_READY) {
160 for (auto it = _inferRequests.begin(); it != _inferRequests.end(); ++it) {
161 startTime = std::chrono::high_resolution_clock::now();
162 status = it->_request->Wait(millis_timeout);
163 msLeft = std::chrono::duration_cast<std::chrono::milliseconds>(
164 std::chrono::high_resolution_clock::now() - startTime).count();
169 if (millis_timeout - msLeft > 0) {
170 millis_timeout -= msLeft;
171 } else if (it != _inferRequests.end()) {
172 return RESULT_NOT_READY;