f97bbe7b642f2287dad8bd194a7105c2cdd01368
[platform/upstream/dldt.git] / inference-engine / src / hetero_plugin / fallback_policy.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "fallback_policy.h"
7 #include "hetero_device_loader.h"
8 #include "details/ie_cnn_network_iterator.hpp"
9 #include "ie_layers.h"
10 #include "ie_util_internal.hpp"
11 #include <fstream>
12 #include <vector>
13 #include <memory>
14
15 using namespace InferenceEngine;
16
17 void dla_layer_colorer(const CNNLayerPtr layer,
18                        ordered_properties &printed_properties,
19                        ordered_properties &node_properties) {
20     printed_properties.insert(printed_properties.begin(),
21                               std::pair<std::string, std::string>("device", layer->affinity));
22     if (layer->affinity == "CPU") {
23         node_properties.emplace_back("fillcolor", "#5A5DF0");
24     } else if (layer->affinity == "FPGA") {
25         node_properties.emplace_back("fillcolor", "#20F608");
26     } else if (layer->affinity == "GPU") {
27         node_properties.emplace_back("fillcolor", "#F1F290");
28     } else {
29         node_properties.emplace_back("fillcolor", "#11F110");
30     }
31 }
32
33
34 FallbackPolicy::FallbackPolicy(std::map<std::string, InferenceEngine::IHeteroDeviceLoader::Ptr> &deviceLoaders,
35                                bool dumpDotFile) :
36     _deviceLoaders(deviceLoaders),
37     _dumpDotFile(dumpDotFile) {
38 }
39
40 void FallbackPolicy::init(const std::string &config, const std::map<std::string, std::string> &allConfigs,
41                           const std::vector<InferenceEngine::IExtensionPtr> &extensions) {
42     if (config.empty()) {
43         THROW_IE_EXCEPTION << "Cannot set affinity according to fallback policy because the order of devices was not initialized";
44     }
45     // parsing the string and splitting to tokens
46     std::string::size_type i = 0;
47     std::string::size_type idelimeter;
48     while ((idelimeter = config.find(',', i)) != std::string::npos) {
49         _fallbackDevices.push_back(config.substr(i, idelimeter - i));
50         i = idelimeter + 1;
51     }
52     _fallbackDevices.push_back(config.substr(i, config.length() - i));
53
54     for (auto d : _fallbackDevices) {
55         if (_deviceLoaders.find(d) == _deviceLoaders.end()) {
56             IHeteroDeviceLoader::Ptr loader;
57             loader = std::make_shared<HeteroDeviceLoader>(d);
58             HeteroDeviceLoader *pdl = dynamic_cast<HeteroDeviceLoader *>(loader.get());
59             pdl->initConfigs(allConfigs, extensions);
60             _deviceLoaders[d] = loader;
61         }
62     }
63 }
64
65 void FallbackPolicy::setAffinity(const std::map<std::string, std::string>& config, ICNNNetwork& network) {
66     std::map<std::string, QueryNetworkResult> queryResults;
67     // go oger devices, create appropriate plugins and
68     for (const auto &i : _fallbackDevices) {
69         QueryNetworkResult r;
70         _deviceLoaders[i]->QueryNetwork(i, network, config, r);
71         queryResults[i] = r;
72     }
73
74     details::CNNNetworkIterator i(const_cast<ICNNNetwork *>(&network));
75     while (i != details::CNNNetworkIterator()) {
76         CNNLayer::Ptr layer = *i;
77         for (auto &&j : _fallbackDevices) {
78             auto &qr = queryResults[j];
79             if (qr.supportedLayers.find(layer->name) != qr.supportedLayers.end()) {
80                 layer->affinity = j;
81                 break;
82             }
83         }
84         i++;
85     }
86
87     if (_dumpDotFile) {
88         std::stringstream stream(std::stringstream::out);
89         stream << "hetero_affinity_" << network.getName() << ".dot";
90
91         std::ofstream file(stream.str().c_str());
92         saveGraphToDot(network, file, dla_layer_colorer);
93     }
94 }