CMAKE: moved GNA var setting to proper place; removed find_package when build python...
[platform/upstream/dldt.git] / inference-engine / src / hetero_plugin / hetero_plugin.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ie_metric_helpers.hpp"
6 #include "hetero_plugin.hpp"
7 #include "ie_util_internal.hpp"
8 #include <memory>
9 #include <vector>
10 #include <map>
11 #include <string>
12 #include <utility>
13 #include <fstream>
14 #include <unordered_set>
15 #include "ie_plugin_config.hpp"
16 #include "hetero/hetero_plugin_config.hpp"
17 #include <cpp_interfaces/base/ie_plugin_base.hpp>
18 #include "hetero_executable_network.hpp"
19 #include "cpp_interfaces/base/ie_inference_plugin_api.hpp"
20
21 using namespace InferenceEngine;
22 using namespace InferenceEngine::PluginConfigParams;
23 using namespace InferenceEngine::HeteroConfigParams;
24 using namespace HeteroPlugin;
25 using namespace std;
26
27 static Version heteroPluginDescription = {
28         {2, 1},  // plugin API version
29         CI_BUILD_NUMBER,
30         "heteroPlugin"  // plugin description message
31 };
32
33 void Engine::GetVersion(const Version *&versionInfo)noexcept {
34     versionInfo = &heteroPluginDescription;
35 }
36
37 Engine::Engine() {
38     _pluginName = "HETERO";
39     _config[InferenceEngine::PluginConfigParams::KEY_EXCLUSIVE_ASYNC_REQUESTS] = "YES";
40     _config[HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)] = NO;
41 }
42
43 InferenceEngine::ExecutableNetworkInternal::Ptr Engine::LoadExeNetworkImpl(const ICore*                     core,
44                                                                            InferenceEngine::ICNNNetwork&    network,
45                                                                            const Configs&                   config) {
46     // TODO(amalyshe) do we need here verification of input precisions?
47     Configs tconfig;
48     tconfig = config;
49
50     // we must not override the parameter, but need to copy everything from plugin config
51     for (auto && c : _config) {
52         if (tconfig.find(c.first) == tconfig.end()) {
53             tconfig[c.first] = c.second;
54         }
55     }
56
57     return std::make_shared<HeteroExecutableNetwork>(network, tconfig, this);
58 }
59
60 namespace  {
61
62 IInferencePluginAPI * getInferencePluginAPIInterface(IInferencePlugin * iplugin) {
63     return dynamic_cast<IInferencePluginAPI *>(iplugin);
64 }
65
66 IInferencePluginAPI * getInferencePluginAPIInterface(InferenceEnginePluginPtr iplugin) {
67     return getInferencePluginAPIInterface(static_cast<IInferencePlugin *>(iplugin.operator->()));
68 }
69
70 IInferencePluginAPI * getInferencePluginAPIInterface(InferencePlugin plugin) {
71     return getInferencePluginAPIInterface(static_cast<InferenceEnginePluginPtr>(plugin));
72 }
73
74 }  // namespace
75
76
77 Engine::Configs Engine::GetSupportedConfig(const Engine::Configs& config, const InferenceEngine::InferencePlugin& plugin) {
78     auto pluginApi = getInferencePluginAPIInterface(plugin);
79     std::vector<std::string> supportedConfigKeys = pluginApi->GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), {});
80     Engine::Configs supportedConfig;
81     for (auto&& key : supportedConfigKeys) {
82         auto itKey = config.find(key);
83         if (config.end() != itKey) {
84             supportedConfig[key] = itKey->second;
85         }
86     }
87     return supportedConfig;
88 }
89
90
91 InferenceEngine::InferencePlugin Engine::GetDevicePlugin(const std::string& device) const {
92     InferenceEngine::InferencePlugin plugin;
93     if (nullptr == _core) {
94         IE_SUPPRESS_DEPRECATED_START
95         // try to create plugin
96         PluginDispatcher dispatcher({""});
97         plugin = dispatcher.getPluginByDevice(device);
98         IE_SUPPRESS_DEPRECATED_END
99     } else {
100         plugin = InferencePlugin{_core->GetPluginByName(device)};
101     }
102
103     try {
104         for (auto&& ext : _extensions) {
105             plugin.AddExtension(ext);
106         }
107     } catch (InferenceEngine::details::InferenceEngineException &) {}
108
109     plugin.SetConfig(GetSupportedConfig(_config, plugin));
110
111     if (nullptr != _errorListener) {
112         static_cast<InferenceEnginePluginPtr>(plugin)->SetLogCallback(*_errorListener);
113     }
114
115     return plugin;
116 }
117
118 Engine::Plugins Engine::GetDevicePlugins(const std::string& targetFallback) const {
119     auto devices = InferenceEngine::DeviceIDParser::getHeteroDevices(targetFallback);
120     Engine::Plugins plugins = _plugins;
121     for (auto&& device : devices) {
122         auto itPlugin = plugins.find(device);
123         if (plugins.end() == itPlugin) {
124             plugins[device] = GetDevicePlugin(device);
125         }
126     }
127     return plugins;
128 }
129
130 Engine::Plugins Engine::GetDevicePlugins(const std::string& targetFallback) {
131     _plugins = const_cast<const Engine*>(this)->GetDevicePlugins(targetFallback);
132     return _plugins;
133 }
134
135 void Engine::SetConfig(const Configs &configs) {
136     for (auto&& config : configs) {
137         _config[config.first] = config.second;
138     }
139
140     for (auto&& plugin : _plugins) {
141         plugin.second.SetConfig(GetSupportedConfig(configs, plugin.second));
142     }
143 }
144
145 void Engine::AddExtension(InferenceEngine::IExtensionPtr extension) {
146     _extensions.emplace_back(std::move(extension));
147     try {
148         for (auto&& plugin : _plugins) {
149             plugin.second.AddExtension(std::move(extension));
150         }
151     } catch (InferenceEngine::details::InferenceEngineException &) {}
152 }
153
154 HeteroLayerColorer::HeteroLayerColorer(const std::vector<std::string>& devices) {
155     static const std::vector<std::string> colors = {"#5A5DF0", "#20F608", "#F1F290", "#11F110"};
156     for (auto&& device : devices) {
157         deviceColorMap[device] = colors[std::distance(&device, devices.data()) % colors.size()];
158     }
159 }
160
161 void HeteroLayerColorer::operator()(const CNNLayerPtr layer,
162                 ordered_properties &printed_properties,
163                 ordered_properties &node_properties) {
164     auto device = layer->affinity;
165     printed_properties.insert(printed_properties.begin(), std::make_pair("device", device));
166     node_properties.emplace_back("fillcolor", deviceColorMap[device]);
167 }
168
169 void Engine::SetAffinity(InferenceEngine::ICNNNetwork &network, const Configs &config) {
170     auto it = config.find("TARGET_FALLBACK");
171     if (it == config.end()) {
172         it = _config.find("TARGET_FALLBACK");
173
174         if (it == _config.end()) {
175             THROW_IE_EXCEPTION << "The 'TARGET_FALLBACK' option was not defined for heterogeneous plugin";
176         }
177     }
178
179     GetDevicePlugins(it->second);
180     SetConfig(config);
181     QueryNetworkResult qr;
182     QueryNetwork(network, config, qr);
183
184     details::CNNNetworkIterator i(const_cast<ICNNNetwork *>(&network));
185     while (i != details::CNNNetworkIterator()) {
186         CNNLayer::Ptr layer = *i;
187         auto it = qr.supportedLayersMap.find(layer->name);
188         if (it != qr.supportedLayersMap.end()) {
189             layer->affinity = it->second;
190         }
191         i++;
192     }
193
194     if ("YES" == _config[HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)]) {
195         std::unordered_set<std::string> devicesSet;
196         details::CNNNetworkIterator i(&network);
197         while (i != details::CNNNetworkIterator()) {
198             CNNLayer::Ptr layer = *i;
199             if (!layer->affinity.empty()) {
200                 devicesSet.insert(layer->affinity);
201             }
202             i++;
203         }
204         std::vector<std::string> devices{std::begin(devicesSet), std::end(devicesSet)};
205         std::stringstream stream(std::stringstream::out);
206         stream << "hetero_affinity_" << network.getName() << ".dot";
207
208         std::ofstream file(stream.str());
209
210         saveGraphToDot(network, file, HeteroLayerColorer{devices});
211     }
212 }
213
214 void Engine::SetLogCallback(IErrorListener &listener) {
215     _errorListener = &listener;
216     for (auto&& plugin : _plugins) {
217         static_cast<InferenceEnginePluginPtr>(plugin.second)->SetLogCallback(*_errorListener);
218     }
219 }
220
221 void Engine::QueryNetwork(const ICNNNetwork &network, const Configs& config, QueryNetworkResult &qr) const {
222     auto it = config.find("TARGET_FALLBACK");
223     if (it == config.end()) {
224         it = _config.find("TARGET_FALLBACK");
225
226         if (it == _config.end()) {
227             THROW_IE_EXCEPTION << "The 'TARGET_FALLBACK' option was not defined for heterogeneous plugin";
228         }
229     }
230
231     Plugins plugins = GetDevicePlugins(it->second);
232
233     qr.rc = StatusCode::OK;
234
235     std::map<std::string, QueryNetworkResult> queryResults;
236     // go over devices, create appropriate plugins and
237     for (auto&& value : plugins) {
238         auto& device = value.first;
239         auto& plugin = value.second;
240         QueryNetworkResult r;
241         plugin.QueryNetwork(network, GetSupportedConfig(config, plugin), r);
242         queryResults[device] = r;
243     }
244
245     //  WARNING: Here is devices with user set priority
246     auto falbackDevices = InferenceEngine::DeviceIDParser::getHeteroDevices(it->second);
247
248     details::CNNNetworkIterator i(const_cast<ICNNNetwork *>(&network));
249     while (i != details::CNNNetworkIterator()) {
250         CNNLayer::Ptr layer = *i;
251         for (auto&& device : falbackDevices) {
252             auto& deviceQueryResult = queryResults[device];
253             if (deviceQueryResult.supportedLayersMap.find(layer->name) != deviceQueryResult.supportedLayersMap.end()) {
254                 qr.supportedLayersMap[layer->name] = device;
255                 break;
256             }
257         }
258         i++;
259     }
260 }
261
262 Parameter Engine::GetMetric(const std::string& name, const std::map<std::string, Parameter> & options) const {
263     if (METRIC_KEY(SUPPORTED_METRICS) == name) {
264         IE_SET_METRIC_RETURN(SUPPORTED_METRICS, std::vector<std::string>{
265             METRIC_KEY(SUPPORTED_METRICS),
266             METRIC_KEY(SUPPORTED_CONFIG_KEYS)});
267     } else if (METRIC_KEY(SUPPORTED_CONFIG_KEYS) == name) {
268         IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, std::vector<std::string>{
269             HETERO_CONFIG_KEY(DUMP_GRAPH_DOT),
270             "TARGET_FALLBACK",
271             CONFIG_KEY(EXCLUSIVE_ASYNC_REQUESTS)});
272     } else {
273         THROW_IE_EXCEPTION << "Unsupported Plugin metric: " << name;
274     }
275 }
276
277 Parameter Engine::GetConfig(const std::string& name, const std::map<std::string, Parameter> & options) const {
278     if (name == HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)) {
279         auto it = _config.find(HETERO_CONFIG_KEY(DUMP_GRAPH_DOT));
280         IE_ASSERT(it != _config.end());
281         bool dump = it->second == YES;
282         return { dump };
283     } else {
284         THROW_IE_EXCEPTION << "Unsupported config key: " << name;
285     }
286 }
287
288 INFERENCE_PLUGIN_API(InferenceEngine::StatusCode) CreatePluginEngine(
289         InferenceEngine::IInferencePlugin *&plugin,
290         InferenceEngine::ResponseDesc *resp) noexcept {
291     try {
292         plugin = make_ie_compatible_plugin({2, 1, CI_BUILD_NUMBER, "heteroPlugin"},
293                                            std::make_shared<Engine>());
294         return OK;
295     }
296     catch (std::exception &ex) {
297         return DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
298     }
299 }