Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ie_graph_splitter.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include "ie_graph_splitter.hpp"
6
7 #include <cassert>
8 #include <unordered_map>
9 #include <unordered_set>
10 #include <utility>
11 #include <vector>
12 #include <string>
13
14 #include <ade_util.hpp>
15
16 #include <ade/typed_graph.hpp>
17 #include <ade/helpers/subgraphs.hpp>
18
19 #include <ade/util/filter_range.hpp>
20 #include <ade/util/iota_range.hpp>
21
22 namespace InferenceEngine {
23
24 namespace {
25 class ISplitChecker {
26 public:
27     struct GraphSelectionResult final {
28         static const constexpr std::size_t NoGraph
29             = static_cast<std::size_t>(-1);
30
31         std::size_t selectedGraph = NoGraph;
32         bool continueSelect = false;
33     };
34
35     virtual ~ISplitChecker() = default;
36     virtual GraphSelectionResult selectSubgraph(
37             const std::vector<LayersSet>& subgraphs) = 0;
38 };
39
40 class DefaultSplitChecker : public ISplitChecker {
41 public:
42     // ISplitChecker interface
43     GraphSelectionResult selectSubgraph(const std::vector<LayersSet>& subgraphs) override;
44 };
45 }  // namespace
46
47 std::vector<LayersSet> splitGraph(ICNNNetwork& network,
48         const std::vector<std::string>& plugins) {
49     assert(!plugins.empty());
50     ade::Graph gr;
51     ade::TypedGraph<CNNLayerMetadata> tgr(gr);
52
53     std::vector<LayersSet> tempSubgraphs;
54     LayersSet tempSet1;
55     LayersSet tempSet2;
56
57     translateNetworkToAde(gr, network);
58     std::size_t currentChecker = 0;
59
60     DefaultSplitChecker checker;
61
62     auto getChecker = [&]() {
63         assert(currentChecker < plugins.size());
64         return &checker;
65     };
66
67     auto getAffinity = [&]()->const std::string& {
68         assert(currentChecker < plugins.size());
69         return plugins[currentChecker];
70     };
71
72     auto nodes = gr.nodes();
73     ade::subgraphs::NodesSet availableNodes(nodes.begin(), nodes.end());
74     std::vector<LayersSet> finalSubgraphs;
75     ade::SubgraphSelfReferenceChecker cycleChecker(nodes);
76     while (!availableNodes.empty()) {
77         auto subgraphs = ade::selectSubgraphs(
78                              ade::util::filter(ade::util::toRange(availableNodes),
79                                           [&](const ade::NodeHandle& node) {
80             assert(nullptr != node);
81             auto layer = tgr.metadata(node).get<CNNLayerMetadata>().layer;
82             assert(nullptr != layer);
83             return layer->affinity == getAffinity();
84         }),
85                              [&](
86                              const ade::EdgeHandle& edge,
87                              ade::SubgraphMergeDirection dir) {
88             assert(nullptr != edge);
89             auto dstNode = ade::getDstMergeNode(edge, dir);
90             assert(nullptr != dstNode);
91             if (!ade::util::contains(availableNodes, dstNode)) {
92                 return false;
93             }
94             auto srcNode = ade::getSrcMergeNode(edge, dir);
95             assert(nullptr != srcNode);
96             auto srcLayer = tgr.metadata(srcNode).get<CNNLayerMetadata>().layer;
97             auto dstLayer = tgr.metadata(dstNode).get<CNNLayerMetadata>().layer;
98             assert(nullptr != srcLayer);
99             assert(nullptr != dstLayer);
100             return srcLayer->affinity == dstLayer->affinity;
101         },
102                              [&](
103                              const ade::subgraphs::NodesSet& acceptedNodes,
104                              const ade::subgraphs::NodesSet& rejectedNodes) {
105             if (cycleChecker(acceptedNodes, rejectedNodes)) {
106                 return false;
107             }
108             return true;
109         });
110
111         if (!subgraphs.empty()) {
112             if (plugins.size() == currentChecker) {
113                 THROW_IE_EXCEPTION << "Some nodes weren't assigned to plugin";
114             }
115
116             tempSubgraphs.clear();
117             for (auto&& subgraph : subgraphs) {
118                 assert(!subgraph.empty());
119                 tempSet1.clear();
120                 for (auto&& node : subgraph) {
121                     assert(nullptr != node);
122                     auto layer = tgr.metadata(node).get<CNNLayerMetadata>().layer;
123                     assert(nullptr != layer);
124                     tempSet1.insert(layer);
125                 }
126                 tempSubgraphs.emplace_back(std::move(tempSet1));
127             }
128             auto result = getChecker()->selectSubgraph(tempSubgraphs);
129             const auto selected = result.selectedGraph;
130             if (ISplitChecker::GraphSelectionResult::NoGraph !=
131                     selected) {
132                 assert(selected < subgraphs.size());
133                 finalSubgraphs.emplace_back(std::move(tempSubgraphs[selected]));
134
135                 for (auto&& node : subgraphs[selected]) {
136                     availableNodes.erase(node);
137                 }
138
139                 if (result.continueSelect) {
140                     continue;
141                 }
142             }
143         }
144         ++currentChecker;
145     }
146
147     return finalSubgraphs;
148 }
149
150 ISplitChecker::GraphSelectionResult DefaultSplitChecker::selectSubgraph(
151         const std::vector<LayersSet>& subgraphs) {
152     assert(!subgraphs.empty());
153     std::size_t index = 0;
154     auto maxSize = subgraphs[0].size();
155     for (auto i : ade::util::iota(std::size_t(1), subgraphs.size())) {
156         auto size = subgraphs[i].size();
157         if (size > maxSize) {
158             index = 1;
159             maxSize = size;
160         }
161     }
162     GraphSelectionResult ret;
163     ret.selectedGraph = index;
164     ret.continueSelect = true;
165     return ret;
166 }
167
namespace {
/**
 * @brief Bookkeeping record for one subgraph during topological numbering.
 */
struct SubgraphDesc {
    /// Position in topological order; static_cast<std::size_t>(-1) until assigned.
    std::size_t topoIndex = static_cast<std::size_t>(-1);
    /// Indices (into the descriptor vector) of subgraphs this one depends on.
    std::unordered_set<std::size_t> dependsOn;
};

/**
 * @brief Depth-first numbering step: numbers all dependencies of @p subgraph
 *        first, then @p subgraph itself.
 *
 * The recursion has no cycle guard of its own, so the dependency relation
 * must be acyclic.
 *
 * @param subgraphs All descriptors; entries of dependsOn index into this vector.
 * @param subgraph  Descriptor to number; left untouched if already numbered.
 * @param topoIndex Next free topological index; incremented once per newly
 *                  numbered descriptor.
 */
void topoVisitSubgraph(std::vector<SubgraphDesc>& subgraphs,
                       SubgraphDesc& subgraph,
                       std::size_t& topoIndex) {
    constexpr std::size_t notNumbered = static_cast<std::size_t>(-1);

    const bool alreadyNumbered = (notNumbered != subgraph.topoIndex);
    if (alreadyNumbered) {
        // Any index handed out earlier is necessarily below the running counter.
        assert(subgraph.topoIndex < topoIndex);
        return;
    }

    // Dependencies receive their indices before the dependent subgraph does.
    for (const std::size_t depIndex : subgraph.dependsOn) {
        SubgraphDesc& dependency = subgraphs[depIndex];
        topoVisitSubgraph(subgraphs, dependency, topoIndex);
    }

    subgraph.topoIndex = topoIndex;
    topoIndex += 1;
}
}  // namespace
189
190 void sortSubgraphs(std::vector<LayersSet>& subgraphs) {
191     std::vector<SubgraphDesc> descs(subgraphs.size());
192
193     for (auto i : ade::util::iota(subgraphs.size())) {
194         auto& subgraph = subgraphs[i];
195         assert(!subgraph.empty());
196         for (auto&& layer : subgraph) {
197             assert(nullptr != layer);
198             for (auto&& dataIt : layer->insData) {
199                 auto data = dataIt.lock();
200                 assert(nullptr != data);
201                 auto prevLayer = data->creatorLayer.lock();
202                 if (nullptr != prevLayer) {
203                     for (auto j : ade::util::iota(subgraphs.size())) {
204                         if (i != j) {
205                             if (ade::util::contains(subgraphs[j], prevLayer)) {
206                                 descs[i].dependsOn.insert(j);
207                                 break;
208                             }
209                         }
210                     }
211                 }
212             }
213         }
214     }
215
216     {
217         std::size_t topoIndex = 0;
218         for (auto&& desc : descs) {
219             topoVisitSubgraph(descs, desc, topoIndex);
220         }
221         assert(subgraphs.size() == topoIndex);
222     }
223
224     std::vector<LayersSet> ret(subgraphs.size());
225     for (auto i : ade::util::iota(subgraphs.size())) {
226         assert(i < descs.size());
227         auto& desc = descs[i];
228         auto topoIndex = desc.topoIndex;
229         assert(topoIndex != static_cast<std::size_t>(-1));
230         assert(topoIndex < ret.size());
231         assert(!subgraphs[i].empty());
232         ret[topoIndex] = std::move(subgraphs[i]);
233     }
234     subgraphs = std::move(ret);
235 }
236
237 }  // namespace InferenceEngine