Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / cpp_interfaces / impl / ie_plugin_internal.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * \brief inference engine plugin API wrapper, to be used by particular implementors
7  * \file ie_plugin_base.hpp
8  */
9
10 #pragma once
11
12 #include <memory>
13 #include <map>
14 #include <string>
15 #include <blob_factory.hpp>
16 #include "graph_transformer.h"
17 #include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
18 #include "cpp_interfaces/base/ie_executable_network_base.hpp"
19 #include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
20 #include "ie_memcpy.h"
21
22 namespace InferenceEngine {
23
24 /**
25  * @brief optional implementation of IInferencePluginInternal to avoid duplication in all plugins
26  */
27 class InferencePluginInternal
28         : public IInferencePluginInternal, public std::enable_shared_from_this<InferencePluginInternal> {
29 public:
30     /**
31      * Given optional implementation of deprecated load to avoid need for it to be implemented by plugin
32      */
33     void LoadNetwork(ICNNNetwork &network) override {
34         _isDeprecatedLoad = true;
35         network.getInputsInfo(_networkInputs);
36         network.getOutputsInfo(_networkOutputs);
37         if (_networkInputs.empty() || _networkOutputs.empty()) {
38             THROW_IE_EXCEPTION << "The network doesn't have inputs/outputs.";
39         }
40         _createdInferRequest = nullptr;  // first release the infer request
41         _loadedNetwork = nullptr;  // first release the loaded network
42
43         _firstInput = _networkInputs.begin()->first;
44         _firstOutput = _networkOutputs.begin()->first;
45         LoadNetwork(_loadedNetwork, network, {});
46
47         ResponseDesc resp;
48         StatusCode sts = _loadedNetwork->CreateInferRequest(_createdInferRequest, &resp);
49         if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
50     }
51     /**
52      * @brief most plugins successfully consume unreshapable networks - lets do it in base class
53      * WARNING: this functions modifies layers in input network and might affect application, that uses it
54      */
55     virtual ICNNNetwork&  RemoveConstLayers(ICNNNetwork &network) {
56         auto* implNetwork = dynamic_cast<details::CNNNetworkImpl*>(&network);
57         if (implNetwork) {
58             // valid for CNNNetworkImpl only, while there's no API in ICNNNetwork to change network
59             ConstTransformer transformator(implNetwork);
60             transformator.fullTrim();
61         }
62         return network;
63     }
64
65     /**
66      * @brief Creates an executable network from an pares network object, users can create as many networks as they need and use
67      *        them simultaneously (up to the limitation of the HW resources)
68      * @param network - a network object acquired from CNNNetReader
69      * @param config string-string map of config parameters relevant only for this load operation
70      * @return shared_ptr to the ExecutableNetwork object
71      */
72     virtual ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(ICNNNetwork &network,
73                                                               const std::map<std::string, std::string> &config) = 0;
74
75     /**
76      * Given optional implementation of load executable network to avoid need for it to be implemented by plugin
77      */
78     void LoadNetwork(IExecutableNetwork::Ptr &executableNetwork,
79                      ICNNNetwork &network,
80                      const std::map<std::string, std::string> &config) override {
81         InputsDataMap networkInputs;
82         OutputsDataMap networkOutputs;
83         network.getInputsInfo(networkInputs);
84         network.getOutputsInfo(networkOutputs);
85         _networkInputs.clear();
86         _networkOutputs.clear();
87
88         for (const auto& it : networkInputs) {
89             InputInfo::Ptr newPtr;
90             if (it.second) {
91                 newPtr.reset(new InputInfo());
92                 DataPtr newData(new Data(*it.second->getInputData()));
93                 newPtr->getPreProcess() = it.second->getPreProcess();
94                 if (newPtr->getPreProcess().getMeanVariant() == MEAN_IMAGE) {
95                     for (size_t i = 0; i < newPtr->getPreProcess().getNumberOfChannels(); i++) {
96                         auto blob = newPtr->getPreProcess()[i]->meanData;
97                         newPtr->getPreProcess()[i]->meanData =
98                                 make_blob_with_precision(newPtr->getPreProcess()[i]->meanData->getTensorDesc());
99                         newPtr->getPreProcess()[i]->meanData->allocate();
100                         ie_memcpy(newPtr->getPreProcess()[i]->meanData->buffer(), newPtr->getPreProcess()[i]->meanData->byteSize(),
101                                   blob->cbuffer(), blob->byteSize());
102                     }
103                 }
104                 newData->inputTo.clear();
105                 newPtr->setInputData(newData);
106             }
107             _networkInputs[it.first] = newPtr;
108         }
109
110         for (const auto& it : networkOutputs) {
111             DataPtr newData;
112             if (it.second) {
113                 newData.reset(new Data(*it.second));
114                 newData->inputTo.clear();
115             }
116             _networkOutputs[it.first] = newData;
117         }
118         auto impl = LoadExeNetworkImpl(RemoveConstLayers(network), config);
119         impl->setNetworkInputs(_networkInputs);
120         impl->setNetworkOutputs(_networkOutputs);
121         // skip setting shared ptr to avoid curricular dependency: ExecutableNetworkBase -> IExecutableNetworkInternal -> InferencePluginInternal
122         if (!_isDeprecatedLoad) {
123             impl->SetPointerToPluginInternal(shared_from_this());
124         }
125
126         executableNetwork.reset(new ExecutableNetworkBase<ExecutableNetworkInternal>(impl), [](details::IRelease *p) {
127             p->Release();
128         });
129         _isDeprecatedLoad = false;
130     };
131
132     /**
133      * Given optional implementation of deprecated infer to avoid need for it to be implemented by plugin
134      */
135     void Infer(const Blob &input, Blob &result) override {
136         const BlobMap inputs = {{_firstInput, std::shared_ptr<Blob>(const_cast<Blob *>(&input), [](Blob *ptr) {})}};
137         BlobMap results = {{_firstOutput, std::shared_ptr<Blob>(&result, [](Blob *ptr) {})}};
138         return Infer(inputs, results);
139     }
140
141     /**
142      * Given optional implementation of deprecated infer to avoid need for it to be implemented by plugin
143      */
144     void Infer(const BlobMap &input, BlobMap &result) override {
145         if (_createdInferRequest == nullptr) {
146             THROW_IE_EXCEPTION << NETWORK_NOT_LOADED_str;
147         }
148         ResponseDesc resp;
149         StatusCode sts;
150
151         auto setBlobs = [&](const BlobMap &blobMap) {
152             for (auto pair : blobMap) {
153                 auto blobName = pair.first;
154                 auto blobPtr = pair.second;
155                 sts = _createdInferRequest->SetBlob(blobName.c_str(), blobPtr, &resp);
156                 if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
157             }
158         };
159         setBlobs(input);
160         setBlobs(result);
161
162         sts = _createdInferRequest->Infer(&resp);
163         if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
164     }
165
166     /**
167      * Given optional implementation of deprecated infer to avoid need for it to be implemented by plugin
168      */
169     void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) override {
170         if (_createdInferRequest == nullptr) {
171             THROW_IE_EXCEPTION << NETWORK_NOT_LOADED_str;
172         }
173         ResponseDesc resp;
174         StatusCode sts = _createdInferRequest->GetPerformanceCounts(perfMap, &resp);
175         if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
176     }
177
178     /**
179      * Given optional implementation of ImportNetwork to avoid need for it to be implemented by plugin
180      */
181     IExecutableNetwork::Ptr ImportNetwork(const std::string &modelFileName, const std::map<std::string, std::string> &config) override {
182         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
183     }
184
185     /**
186      * Given optional implementation of SetConfig to avoid need for it to be implemented by plugin
187      */
188     void SetConfig(const std::map<std::string, std::string> &config) override {}
189
190     /**
191      * Given optional implementation of SetLogCallback to avoid need for it to be implemented by plugin
192      */
193     void SetLogCallback(IErrorListener &listener) override {}
194
195     /**
196      * Given optional implementation of AddExtension to avoid need for it to be implemented by plugin
197      */
198     void AddExtension(InferenceEngine::IExtensionPtr extension) override {
199         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
200     }
201     /**
202      * @depricated Use the version with config parameter
203      */
204     void QueryNetwork(const ICNNNetwork& network, QueryNetworkResult& res) const override {
205         QueryNetwork(network, {}, res);
206     }
207
208     void QueryNetwork(const ICNNNetwork &network, const std::map<std::string, std::string>& config, QueryNetworkResult &res) const override {
209         THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
210     }
211
212
213 protected:
214     IExecutableNetwork::Ptr _loadedNetwork;
215     std::string _firstInput;
216     std::string _firstOutput;
217     IInferRequest::Ptr _createdInferRequest;
218     InferenceEngine::InputsDataMap _networkInputs;
219     InferenceEngine::OutputsDataMap _networkOutputs;
220     std::map<std::string, std::string> _config;
221     bool _isDeprecatedLoad = false;
222 };
223
224 }  // namespace InferenceEngine