1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include "ie_graph_splitter.hpp"
8 #include <unordered_map>
9 #include <unordered_set>
14 #include <ade_util.hpp>
16 #include <ade/typed_graph.hpp>
17 #include <ade/helpers/subgraphs.hpp>
19 #include <ade/util/filter_range.hpp>
20 #include <ade/util/iota_range.hpp>
22 namespace InferenceEngine {
27 struct GraphSelectionResult final {
28 static const constexpr std::size_t NoGraph
29 = static_cast<std::size_t>(-1);
31 std::size_t selectedGraph = NoGraph;
32 bool continueSelect = false;
35 virtual ~ISplitChecker() = default;
36 virtual GraphSelectionResult selectSubgraph(
37 const std::vector<LayersSet>& subgraphs) = 0;
40 class DefaultSplitChecker : public ISplitChecker {
42 // ISplitChecker interface
43 GraphSelectionResult selectSubgraph(const std::vector<LayersSet>& subgraphs) override;
47 std::vector<LayersSet> splitGraph(ICNNNetwork& network,
48 const std::vector<std::string>& plugins) {
49 assert(!plugins.empty());
51 ade::TypedGraph<CNNLayerMetadata> tgr(gr);
53 std::vector<LayersSet> tempSubgraphs;
57 translateNetworkToAde(gr, network);
58 std::size_t currentChecker = 0;
60 DefaultSplitChecker checker;
62 auto getChecker = [&]() {
63 assert(currentChecker < plugins.size());
67 auto getAffinity = [&]()->const std::string& {
68 assert(currentChecker < plugins.size());
69 return plugins[currentChecker];
72 auto nodes = gr.nodes();
73 ade::subgraphs::NodesSet availableNodes(nodes.begin(), nodes.end());
74 std::vector<LayersSet> finalSubgraphs;
75 ade::SubgraphSelfReferenceChecker cycleChecker(nodes);
76 while (!availableNodes.empty()) {
77 auto subgraphs = ade::selectSubgraphs(
78 ade::util::filter(ade::util::toRange(availableNodes),
79 [&](const ade::NodeHandle& node) {
80 assert(nullptr != node);
81 auto layer = tgr.metadata(node).get<CNNLayerMetadata>().layer;
82 assert(nullptr != layer);
83 return layer->affinity == getAffinity();
86 const ade::EdgeHandle& edge,
87 ade::SubgraphMergeDirection dir) {
88 assert(nullptr != edge);
89 auto dstNode = ade::getDstMergeNode(edge, dir);
90 assert(nullptr != dstNode);
91 if (!ade::util::contains(availableNodes, dstNode)) {
94 auto srcNode = ade::getSrcMergeNode(edge, dir);
95 assert(nullptr != srcNode);
96 auto srcLayer = tgr.metadata(srcNode).get<CNNLayerMetadata>().layer;
97 auto dstLayer = tgr.metadata(dstNode).get<CNNLayerMetadata>().layer;
98 assert(nullptr != srcLayer);
99 assert(nullptr != dstLayer);
100 return srcLayer->affinity == dstLayer->affinity;
103 const ade::subgraphs::NodesSet& acceptedNodes,
104 const ade::subgraphs::NodesSet& rejectedNodes) {
105 if (cycleChecker(acceptedNodes, rejectedNodes)) {
111 if (!subgraphs.empty()) {
112 if (plugins.size() == currentChecker) {
113 THROW_IE_EXCEPTION << "Some nodes weren't assigned to plugin";
116 tempSubgraphs.clear();
117 for (auto&& subgraph : subgraphs) {
118 assert(!subgraph.empty());
120 for (auto&& node : subgraph) {
121 assert(nullptr != node);
122 auto layer = tgr.metadata(node).get<CNNLayerMetadata>().layer;
123 assert(nullptr != layer);
124 tempSet1.insert(layer);
126 tempSubgraphs.emplace_back(std::move(tempSet1));
128 auto result = getChecker()->selectSubgraph(tempSubgraphs);
129 const auto selected = result.selectedGraph;
130 if (ISplitChecker::GraphSelectionResult::NoGraph !=
132 assert(selected < subgraphs.size());
133 finalSubgraphs.emplace_back(std::move(tempSubgraphs[selected]));
135 for (auto&& node : subgraphs[selected]) {
136 availableNodes.erase(node);
139 if (result.continueSelect) {
147 return finalSubgraphs;
150 ISplitChecker::GraphSelectionResult DefaultSplitChecker::selectSubgraph(
151 const std::vector<LayersSet>& subgraphs) {
152 assert(!subgraphs.empty());
153 std::size_t index = 0;
154 auto maxSize = subgraphs[0].size();
155 for (auto i : ade::util::iota(std::size_t(1), subgraphs.size())) {
156 auto size = subgraphs[i].size();
157 if (size > maxSize) {
162 GraphSelectionResult ret;
163 ret.selectedGraph = index;
164 ret.continueSelect = true;
169 struct SubgraphDesc {
170 std::size_t topoIndex = static_cast<std::size_t>(-1);
171 std::unordered_set<std::size_t> dependsOn;
174 void topoVisitSubgraph(std::vector<SubgraphDesc>& subgraphs,
175 SubgraphDesc& subgraph,
176 std::size_t& topoIndex) {
177 if (subgraph.topoIndex != static_cast<std::size_t>(-1)) {
178 assert(subgraph.topoIndex < topoIndex);
182 for (auto&& dep : subgraph.dependsOn) {
183 topoVisitSubgraph(subgraphs, subgraphs[dep], topoIndex);
185 subgraph.topoIndex = topoIndex;
190 void sortSubgraphs(std::vector<LayersSet>& subgraphs) {
191 std::vector<SubgraphDesc> descs(subgraphs.size());
193 for (auto i : ade::util::iota(subgraphs.size())) {
194 auto& subgraph = subgraphs[i];
195 assert(!subgraph.empty());
196 for (auto&& layer : subgraph) {
197 assert(nullptr != layer);
198 for (auto&& dataIt : layer->insData) {
199 auto data = dataIt.lock();
200 assert(nullptr != data);
201 auto prevLayer = data->creatorLayer.lock();
202 if (nullptr != prevLayer) {
203 for (auto j : ade::util::iota(subgraphs.size())) {
205 if (ade::util::contains(subgraphs[j], prevLayer)) {
206 descs[i].dependsOn.insert(j);
217 std::size_t topoIndex = 0;
218 for (auto&& desc : descs) {
219 topoVisitSubgraph(descs, desc, topoIndex);
221 assert(subgraphs.size() == topoIndex);
224 std::vector<LayersSet> ret(subgraphs.size());
225 for (auto i : ade::util::iota(subgraphs.size())) {
226 assert(i < descs.size());
227 auto& desc = descs[i];
228 auto topoIndex = desc.topoIndex;
229 assert(topoIndex != static_cast<std::size_t>(-1));
230 assert(topoIndex < ret.size());
231 assert(!subgraphs[i].empty());
232 ret[topoIndex] = std::move(subgraphs[i]);
234 subgraphs = std::move(ret);
237 } // namespace InferenceEngine