1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 * \brief inference engine plugin API wrapper, to be used by particular implementors
7 * \file ie_plugin_base.hpp
15 #include <blob_factory.hpp>
16 #include "graph_transformer.h"
17 #include "cpp_interfaces/interface/ie_iplugin_internal.hpp"
18 #include "cpp_interfaces/base/ie_executable_network_base.hpp"
19 #include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
20 #include "ie_memcpy.h"
22 namespace InferenceEngine {
25 * @brief optional implementation of IInferencePluginInternal to avoid duplication in all plugins
27 class InferencePluginInternal
28 : public IInferencePluginInternal, public std::enable_shared_from_this<InferencePluginInternal> {
31 * Given optional implementation of deprecated load to avoid need for it to be implemented by plugin
// Deprecated single-argument load path: caches the network's I/O maps, loads the
// network through the overridable 3-argument LoadNetwork overload, and eagerly
// creates one internal infer request that the legacy Infer()/GetPerformanceCounts()
// members operate on.
// NOTE(review): several original lines are elided from this chunk (e.g. the closing
// brace of the THROW branch and the `ResponseDesc resp;` declaration used below) —
// verify against the full file.
33 void LoadNetwork(ICNNNetwork &network) override {
// Mark the deprecated path so the 3-arg LoadNetwork skips SetPointerToPluginInternal.
34 _isDeprecatedLoad = true;
35 network.getInputsInfo(_networkInputs);
36 network.getOutputsInfo(_networkOutputs);
37 if (_networkInputs.empty() || _networkOutputs.empty()) {
38 THROW_IE_EXCEPTION << "The network doesn't have inputs/outputs.";
// Release order matters: drop the infer request before the executable network that owns it.
40 _createdInferRequest = nullptr; // first release the infer request
41 _loadedNetwork = nullptr; // first release the loaded network
// Remember the first input/output names for the two-blob Infer() convenience overload.
43 _firstInput = _networkInputs.begin()->first;
44 _firstOutput = _networkOutputs.begin()->first;
// Delegate to the map-based overload with an empty per-load config.
45 LoadNetwork(_loadedNetwork, network, {});
48 StatusCode sts = _loadedNetwork->CreateInferRequest(_createdInferRequest, &resp);
49 if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
52 * @brief most plugins successfully consume unreshapable networks - let's do it in the base class
53 * WARNING: this function modifies layers in the input network and might affect the application that uses it
// Folds constant layers out of `network` in place so plugins that cannot reshape
// still consume it. The transform is only applicable when the runtime type is
// details::CNNNetworkImpl (checked via dynamic_cast).
// WARNING (from original doc): mutates layers of the caller's network.
// NOTE(review): the null-check branch around the cast and the `return` statement
// are not visible in this chunk — presumably the transform runs only when the
// cast succeeds and the (possibly modified) network is returned; confirm.
55 virtual ICNNNetwork& RemoveConstLayers(ICNNNetwork &network) {
56 auto* implNetwork = dynamic_cast<details::CNNNetworkImpl*>(&network);
58 // valid for CNNNetworkImpl only, while there's no API in ICNNNetwork to change network
59 ConstTransformer transformator(implNetwork);
// fullTrim() removes const-only subgraphs from the network.
60 transformator.fullTrim();
66 * @brief Creates an executable network from a parsed network object, users can create as many networks as they need and use
67 * them simultaneously (up to the limitation of the HW resources)
68 * @param network - a network object acquired from CNNNetReader
69 * @param config string-string map of config parameters relevant only for this load operation
70 * @return shared_ptr to the ExecutableNetwork object
// Plugin-specific factory hook: builds the executable network from `network`
// using the per-load `config`; invoked by the generic LoadNetwork implementation
// below. Pure virtual — every concrete plugin must implement it.
72 virtual ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(ICNNNetwork &network,
73 const std::map<std::string, std::string> &config) = 0;
76 * Given optional implementation of load executable network to avoid need for it to be implemented by plugin
// Main load path: deep-copies the network's input/output metadata into the
// plugin's members (so later network mutations don't affect the loaded state),
// delegates actual compilation to the plugin's LoadExeNetworkImpl(), and wraps
// the result in the public IExecutableNetwork shell.
// NOTE(review): this chunk elides original lines — notably the `ICNNNetwork
// &network` parameter (network is clearly used below), several closing braces,
// and the body of the custom deleter at the end — verify against the full file.
78 void LoadNetwork(IExecutableNetwork::Ptr &executableNetwork,
80 const std::map<std::string, std::string> &config) override {
81 InputsDataMap networkInputs;
82 OutputsDataMap networkOutputs;
83 network.getInputsInfo(networkInputs);
84 network.getOutputsInfo(networkOutputs);
85 _networkInputs.clear();
86 _networkOutputs.clear();
// Deep-copy each InputInfo, including per-channel mean-image blobs, so the
// cached maps are independent of the caller's network object.
88 for (const auto& it : networkInputs) {
89 InputInfo::Ptr newPtr;
91 newPtr.reset(new InputInfo());
92 DataPtr newData(new Data(*it.second->getInputData()));
93 newPtr->getPreProcess() = it.second->getPreProcess();
94 if (newPtr->getPreProcess().getMeanVariant() == MEAN_IMAGE) {
// Mean-image data: reallocate a fresh blob per channel and copy the bytes over,
// rather than sharing the original blob.
95 for (size_t i = 0; i < newPtr->getPreProcess().getNumberOfChannels(); i++) {
96 auto blob = newPtr->getPreProcess()[i]->meanData;
97 newPtr->getPreProcess()[i]->meanData =
98 make_blob_with_precision(newPtr->getPreProcess()[i]->meanData->getTensorDesc());
99 newPtr->getPreProcess()[i]->meanData->allocate();
100 ie_memcpy(newPtr->getPreProcess()[i]->meanData->buffer(), newPtr->getPreProcess()[i]->meanData->byteSize(),
101 blob->cbuffer(), blob->byteSize());
// Detach the copied Data from downstream layers of the original graph.
104 newData->inputTo.clear();
105 newPtr->setInputData(newData);
107 _networkInputs[it.first] = newPtr;
// Same deep-copy treatment for outputs.
110 for (const auto& it : networkOutputs) {
113 newData.reset(new Data(*it.second));
114 newData->inputTo.clear();
116 _networkOutputs[it.first] = newData;
// Const-fold the network, then hand it to the plugin-specific factory.
118 auto impl = LoadExeNetworkImpl(RemoveConstLayers(network), config);
119 impl->setNetworkInputs(_networkInputs);
120 impl->setNetworkOutputs(_networkOutputs);
121 // skip setting shared ptr to avoid circular dependency: ExecutableNetworkBase -> IExecutableNetworkInternal -> InferencePluginInternal
122 if (!_isDeprecatedLoad) {
123 impl->SetPointerToPluginInternal(shared_from_this());
// Wrap the internal impl in the public shell with a custom IRelease deleter.
126 executableNetwork.reset(new ExecutableNetworkBase<ExecutableNetworkInternal>(impl), [](details::IRelease *p) {
// Reset the flag so subsequent loads default to the non-deprecated path.
129 _isDeprecatedLoad = false;
133 * Given optional implementation of deprecated infer to avoid need for it to be implemented by plugin
// Deprecated two-blob convenience overload: binds the single input/output pair
// to the first input/output names cached by the deprecated LoadNetwork, then
// forwards to the map-based Infer().
135 void Infer(const Blob &input, Blob &result) override {
// Wrap caller-owned blobs in non-owning shared_ptrs (no-op deleters) — the
// maps must not take ownership of stack/caller memory. The const_cast is
// required by the BlobMap value type; the input is not modified here.
136 const BlobMap inputs = {{_firstInput, std::shared_ptr<Blob>(const_cast<Blob *>(&input), [](Blob *ptr) {})}};
137 BlobMap results = {{_firstOutput, std::shared_ptr<Blob>(&result, [](Blob *ptr) {})}};
138 return Infer(inputs, results);
142 * Given optional implementation of deprecated infer to avoid need for it to be implemented by plugin
// Deprecated map-based inference: pushes the given input/output blobs into the
// infer request created by the deprecated LoadNetwork and runs it synchronously.
// Throws if no network has been loaded through that path.
// NOTE(review): the `StatusCode sts;` / `ResponseDesc resp;` declarations and
// the calls applying `setBlobs` to `input` and `result` are elided from this
// chunk — verify against the full file.
144 void Infer(const BlobMap &input, BlobMap &result) override {
145 if (_createdInferRequest == nullptr) {
146 THROW_IE_EXCEPTION << NETWORK_NOT_LOADED_str;
// Helper that registers every blob of a map on the infer request by name.
151 auto setBlobs = [&](const BlobMap &blobMap) {
152 for (auto pair : blobMap) {
153 auto blobName = pair.first;
154 auto blobPtr = pair.second;
155 sts = _createdInferRequest->SetBlob(blobName.c_str(), blobPtr, &resp);
156 if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
// Synchronous run; any failure is surfaced as an exception with the runtime message.
162 sts = _createdInferRequest->Infer(&resp);
163 if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
167 * Given optional implementation of deprecated infer to avoid need for it to be implemented by plugin
// Deprecated per-layer profiling query: forwards to the infer request created
// by the deprecated LoadNetwork. Throws if no network has been loaded that way.
// NOTE(review): the `ResponseDesc resp;` declaration is elided from this chunk.
169 void GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo> &perfMap) override {
170 if (_createdInferRequest == nullptr) {
171 THROW_IE_EXCEPTION << NETWORK_NOT_LOADED_str;
174 StatusCode sts = _createdInferRequest->GetPerformanceCounts(perfMap, &resp);
175 if (sts != OK) THROW_IE_EXCEPTION << resp.msg;
179 * Given optional implementation of ImportNetwork to avoid need for it to be implemented by plugin
// Default ImportNetwork: importing pre-compiled models is optional, so the base
// class rejects it; plugins that support import override this.
181 IExecutableNetwork::Ptr ImportNetwork(const std::string &modelFileName, const std::map<std::string, std::string> &config) override {
182 THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
186 * Given optional implementation of SetConfig to avoid need for it to be implemented by plugin
// Default SetConfig: intentionally a no-op — plugins with configurable options override it.
188 void SetConfig(const std::map<std::string, std::string> &config) override {}
191 * Given optional implementation of SetLogCallback to avoid need for it to be implemented by plugin
// Default SetLogCallback: intentionally a no-op — plugins that emit log events override it.
193 void SetLogCallback(IErrorListener &listener) override {}
196 * Given optional implementation of AddExtension to avoid need for it to be implemented by plugin
// Default AddExtension: custom-layer extensions are optional, so the base class
// rejects them; plugins supporting extensions override this.
198 void AddExtension(InferenceEngine::IExtensionPtr extension) override {
199 THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
202 * @deprecated Use the version with config parameter
// Deprecated config-less QueryNetwork: forwards to the config-taking overload
// with an empty config map.
204 void QueryNetwork(const ICNNNetwork& network, QueryNetworkResult& res) const override {
205 QueryNetwork(network, {}, res);
// Default QueryNetwork with config: layer-support queries are optional, so the
// base class rejects them; plugins that can report supported layers override this.
208 void QueryNetwork(const ICNNNetwork &network, const std::map<std::string, std::string>& config, QueryNetworkResult &res) const override {
209 THROW_IE_EXCEPTION << NOT_IMPLEMENTED_str;
// Executable network created by the deprecated 1-arg LoadNetwork.
214 IExecutableNetwork::Ptr _loadedNetwork;
// Names of the first input/output, used by the 2-blob Infer() overload.
215 std::string _firstInput;
216 std::string _firstOutput;
// Infer request eagerly created by the deprecated LoadNetwork; used by the
// deprecated Infer()/GetPerformanceCounts().
217 IInferRequest::Ptr _createdInferRequest;
// Deep-copied I/O metadata of the most recently loaded network.
218 InferenceEngine::InputsDataMap _networkInputs;
219 InferenceEngine::OutputsDataMap _networkOutputs;
// Plugin configuration map. NOTE(review): not referenced in the visible part
// of this chunk — presumably used by derived plugins; confirm.
220 std::map<std::string, std::string> _config;
// True while inside the deprecated load path, so the generic LoadNetwork knows
// to skip SetPointerToPluginInternal (avoids the circular dependency noted there).
221 bool _isDeprecatedLoad = false;
224 } // namespace InferenceEngine