// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_metric_helpers.hpp"
#include "hetero_plugin.hpp"
#include "ie_util_internal.hpp"
#include <memory>
#include <vector>
#include <map>
#include <string>
#include <sstream>
#include <fstream>
#include <unordered_set>
#include "ie_plugin_config.hpp"
#include "hetero/hetero_plugin_config.hpp"
#include <cpp_interfaces/base/ie_plugin_base.hpp>
#include "hetero_executable_network.hpp"
#include "cpp_interfaces/base/ie_inference_plugin_api.hpp"
using namespace InferenceEngine;
using namespace InferenceEngine::PluginConfigParams;
using namespace InferenceEngine::HeteroConfigParams;
using namespace HeteroPlugin;
static Version heteroPluginDescription = {
        {2, 1},            // plugin API version
        CI_BUILD_NUMBER,
        "heteroPlugin"     // plugin description message
};
void Engine::GetVersion(const Version *&versionInfo) noexcept {
    versionInfo = &heteroPluginDescription;
}
Engine::Engine() {
    _pluginName = "HETERO";
    _config[InferenceEngine::PluginConfigParams::KEY_EXCLUSIVE_ASYNC_REQUESTS] = "YES";
    _config[HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)] = NO;
}
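// Builds a HeteroExecutableNetwork for the given network. The per-call config takes
// precedence; any keys it does not define are copied over from the plugin-level config.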
InferenceEngine::ExecutableNetworkInternal::Ptr Engine::LoadExeNetworkImpl(const ICore* core,
                                                                           InferenceEngine::ICNNNetwork& network,
                                                                           const Configs& config) {
    // TODO(amalyshe) do we need here verification of input precisions?
    Configs tconfig = config;

    // do not override what the caller passed in, but copy everything else from the plugin config
    for (auto&& c : _config) {
        if (tconfig.find(c.first) == tconfig.end()) {
            tconfig[c.first] = c.second;
        }
    }

    return std::make_shared<HeteroExecutableNetwork>(network, tconfig, this);
}
IInferencePluginAPI * getInferencePluginAPIInterface(IInferencePlugin * iplugin) {
    return dynamic_cast<IInferencePluginAPI *>(iplugin);
}

IInferencePluginAPI * getInferencePluginAPIInterface(InferenceEnginePluginPtr iplugin) {
    return getInferencePluginAPIInterface(static_cast<IInferencePlugin *>(iplugin.operator->()));
}

IInferencePluginAPI * getInferencePluginAPIInterface(InferencePlugin plugin) {
    return getInferencePluginAPIInterface(static_cast<InferenceEnginePluginPtr>(plugin));
}
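// Filters a config map down to the keys that the target plugin advertises through the
// SUPPORTED_CONFIG_KEYS metric, so that only options the plugin understands are forwarded.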
Engine::Configs Engine::GetSupportedConfig(const Engine::Configs& config, const InferenceEngine::InferencePlugin& plugin) {
    auto pluginApi = getInferencePluginAPIInterface(plugin);
    std::vector<std::string> supportedConfigKeys = pluginApi->GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), {});
    Engine::Configs supportedConfig;
    for (auto&& key : supportedConfigKeys) {
        auto itKey = config.find(key);
        if (config.end() != itKey) {
            supportedConfig[key] = itKey->second;
        }
    }
    return supportedConfig;
}
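// Creates a plugin instance for a single device (through ICore when it is available, otherwise
// via the legacy PluginDispatcher) and propagates extensions, supported config options and the
// log callback to it.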
InferenceEngine::InferencePlugin Engine::GetDevicePlugin(const std::string& device) const {
    InferenceEngine::InferencePlugin plugin;
    if (nullptr == _core) {
        IE_SUPPRESS_DEPRECATED_START
        // try to create plugin
        PluginDispatcher dispatcher({""});
        plugin = dispatcher.getPluginByDevice(device);
        IE_SUPPRESS_DEPRECATED_END
    } else {
        plugin = InferencePlugin{_core->GetPluginByName(device)};
    }

    try {
        for (auto&& ext : _extensions) {
            plugin.AddExtension(ext);
        }
    } catch (InferenceEngine::details::InferenceEngineException &) {}

    plugin.SetConfig(GetSupportedConfig(_config, plugin));

    if (nullptr != _errorListener) {
        static_cast<InferenceEnginePluginPtr>(plugin)->SetLogCallback(*_errorListener);
    }

    return plugin;
}
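// Resolves a comma-separated TARGET_FALLBACK string (e.g. "GPU,CPU") into a map of per-device
// plugins, creating only the plugins that have not been created yet.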
Engine::Plugins Engine::GetDevicePlugins(const std::string& targetFallback) const {
    auto devices = InferenceEngine::DeviceIDParser::getHeteroDevices(targetFallback);
    Engine::Plugins plugins = _plugins;
    for (auto&& device : devices) {
        auto itPlugin = plugins.find(device);
        if (plugins.end() == itPlugin) {
            plugins[device] = GetDevicePlugin(device);
        }
    }
    return plugins;
}

Engine::Plugins Engine::GetDevicePlugins(const std::string& targetFallback) {
    _plugins = const_cast<const Engine*>(this)->GetDevicePlugins(targetFallback);
    return _plugins;
}
void Engine::SetConfig(const Configs &configs) {
    for (auto&& config : configs) {
        _config[config.first] = config.second;
    }

    for (auto&& plugin : _plugins) {
        plugin.second.SetConfig(GetSupportedConfig(configs, plugin.second));
    }
}
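// Registers an extension with the HETERO plugin itself and forwards it to every device plugin
// created so far; exceptions thrown by device plugins are swallowed.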
void Engine::AddExtension(InferenceEngine::IExtensionPtr extension) {
    _extensions.emplace_back(extension);

    try {
        for (auto&& plugin : _plugins) {
            plugin.second.AddExtension(extension);
        }
    } catch (InferenceEngine::details::InferenceEngineException &) {}
}
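// Helper used when dumping the affinity graph to .dot: assigns each device a color from a
// fixed palette so that layers mapped to different devices are visually distinguishable.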
HeteroLayerColorer::HeteroLayerColorer(const std::vector<std::string>& devices) {
    static const std::vector<std::string> colors = {"#5A5DF0", "#20F608", "#F1F290", "#11F110"};
    for (auto&& device : devices) {
        deviceColorMap[device] = colors[std::distance(devices.data(), &device) % colors.size()];
    }
}
void HeteroLayerColorer::operator()(const CNNLayerPtr layer,
                                    ordered_properties &printed_properties,
                                    ordered_properties &node_properties) {
    auto device = layer->affinity;
    printed_properties.insert(printed_properties.begin(), std::make_pair("device", device));
    node_properties.emplace_back("fillcolor", deviceColorMap[device]);
}
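// Annotates every layer of the network with a device affinity taken from QueryNetwork results.
// When DUMP_GRAPH_DOT is set to YES, the annotated graph is also written to
// "hetero_affinity_<network name>.dot".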
void Engine::SetAffinity(InferenceEngine::ICNNNetwork &network, const Configs &config) {
    auto it = config.find("TARGET_FALLBACK");
    if (it == config.end()) {
        it = _config.find("TARGET_FALLBACK");

        if (it == _config.end()) {
            THROW_IE_EXCEPTION << "The 'TARGET_FALLBACK' option was not defined for heterogeneous plugin";
        }
    }

    GetDevicePlugins(it->second);

    QueryNetworkResult qr;
    QueryNetwork(network, config, qr);

    details::CNNNetworkIterator i(&network);
    while (i != details::CNNNetworkIterator()) {
        CNNLayer::Ptr layer = *i;
        auto itLayer = qr.supportedLayersMap.find(layer->name);
        if (itLayer != qr.supportedLayersMap.end()) {
            layer->affinity = itLayer->second;
        }
        i++;
    }

    if ("YES" == _config[HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)]) {
        std::unordered_set<std::string> devicesSet;
        details::CNNNetworkIterator i(&network);
        while (i != details::CNNNetworkIterator()) {
            CNNLayer::Ptr layer = *i;
            if (!layer->affinity.empty()) {
                devicesSet.insert(layer->affinity);
            }
            i++;
        }

        std::vector<std::string> devices{std::begin(devicesSet), std::end(devicesSet)};
        std::stringstream stream(std::stringstream::out);
        stream << "hetero_affinity_" << network.getName() << ".dot";

        std::ofstream file(stream.str());
        saveGraphToDot(network, file, HeteroLayerColorer{devices});
    }
}
void Engine::SetLogCallback(IErrorListener &listener) {
    _errorListener = &listener;

    for (auto&& plugin : _plugins) {
        static_cast<InferenceEnginePluginPtr>(plugin.second)->SetLogCallback(*_errorListener);
    }
}
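// Queries every device from the TARGET_FALLBACK list for the layers it supports and, for each
// layer, records the first device (in the user-defined priority order) that can execute it.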
void Engine::QueryNetwork(const ICNNNetwork &network, const Configs& config, QueryNetworkResult &qr) const {
    auto it = config.find("TARGET_FALLBACK");
    if (it == config.end()) {
        it = _config.find("TARGET_FALLBACK");

        if (it == _config.end()) {
            THROW_IE_EXCEPTION << "The 'TARGET_FALLBACK' option was not defined for heterogeneous plugin";
        }
    }

    Plugins plugins = GetDevicePlugins(it->second);

    qr.rc = StatusCode::OK;

    std::map<std::string, QueryNetworkResult> queryResults;
    // go over devices and query each plugin for the layers it supports
    for (auto&& value : plugins) {
        auto& device = value.first;
        auto& plugin = value.second;
        QueryNetworkResult r;
        plugin.QueryNetwork(network, GetSupportedConfig(config, plugin), r);
        queryResults[device] = r;
    }

    // devices are listed in the user-defined fallback priority order
    auto fallbackDevices = InferenceEngine::DeviceIDParser::getHeteroDevices(it->second);

    details::CNNNetworkIterator i(const_cast<ICNNNetwork *>(&network));
    while (i != details::CNNNetworkIterator()) {
        CNNLayer::Ptr layer = *i;
        for (auto&& device : fallbackDevices) {
            auto& deviceQueryResult = queryResults[device];
            if (deviceQueryResult.supportedLayersMap.find(layer->name) != deviceQueryResult.supportedLayersMap.end()) {
                qr.supportedLayersMap[layer->name] = device;
                break;
            }
        }
        i++;
    }
}
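// Plugin-level metrics: HETERO only reports the lists of supported metrics and config keys.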
Parameter Engine::GetMetric(const std::string& name, const std::map<std::string, Parameter> & options) const {
    if (METRIC_KEY(SUPPORTED_METRICS) == name) {
        IE_SET_METRIC_RETURN(SUPPORTED_METRICS, std::vector<std::string>{
            METRIC_KEY(SUPPORTED_METRICS),
            METRIC_KEY(SUPPORTED_CONFIG_KEYS)});
    } else if (METRIC_KEY(SUPPORTED_CONFIG_KEYS) == name) {
        IE_SET_METRIC_RETURN(SUPPORTED_CONFIG_KEYS, std::vector<std::string>{
            HETERO_CONFIG_KEY(DUMP_GRAPH_DOT),
            "TARGET_FALLBACK",
            CONFIG_KEY(EXCLUSIVE_ASYNC_REQUESTS)});
    } else {
        THROW_IE_EXCEPTION << "Unsupported Plugin metric: " << name;
    }
}
Parameter Engine::GetConfig(const std::string& name, const std::map<std::string, Parameter> & options) const {
    if (name == HETERO_CONFIG_KEY(DUMP_GRAPH_DOT)) {
        auto it = _config.find(HETERO_CONFIG_KEY(DUMP_GRAPH_DOT));
        IE_ASSERT(it != _config.end());
        bool dump = it->second == YES;
        return { dump };
    } else {
        THROW_IE_EXCEPTION << "Unsupported config key: " << name;
    }
}
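// Shared-library entry point: instantiates the HETERO Engine and wraps it so it can be exposed
// through the IInferencePlugin interface expected by the Inference Engine core.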
INFERENCE_PLUGIN_API(InferenceEngine::StatusCode) CreatePluginEngine(
        InferenceEngine::IInferencePlugin *&plugin,
        InferenceEngine::ResponseDesc *resp) noexcept {
    try {
        plugin = make_ie_compatible_plugin({2, 1, CI_BUILD_NUMBER, "heteroPlugin"},
                                           std::make_shared<Engine>());
        return OK;
    }
    catch (std::exception &ex) {
        return DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
    }
}