Publishing R3
[platform/upstream/dldt.git] / inference-engine / src / inference_engine / ade_util.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include "ade_util.hpp"
7
8 #include <unordered_map>
9 #include <utility>
10
11 #include <ie_icnn_network.hpp>
12 #include <ie_util_internal.hpp>
13 #include <ie_layers.h>
14
15 #include <util/algorithm.hpp>
16 #include <graph.hpp>
17 #include <typed_graph.hpp>
18
19 namespace InferenceEngine {
20 namespace {
21 using VisitedLayersMap = std::unordered_map<CNNLayer::Ptr, ade::NodeHandle>;
22 using TGraph = ade::TypedGraph<CNNLayerMetadata>;
23
24 void translateVisitLayer(VisitedLayersMap& visited,
25                 TGraph& gr,
26                 const ade::NodeHandle& prevNode,
27                 const CNNLayer::Ptr& layer) {
28     assert(nullptr != layer);;
29     assert(!util::contains(visited, layer));
30     auto node = gr.createNode();
31     gr.metadata(node).set(CNNLayerMetadata{layer});
32     if (nullptr != prevNode) {
33         gr.link(prevNode, node);
34     }
35     visited.insert({layer, node});
36     for (auto&& data : layer->outData) {
37         for (auto&& layerIt : data->inputTo) {
38             auto nextLayer = layerIt.second;
39             auto it = visited.find(nextLayer);
40             if (visited.end() == it) {
41                 translateVisitLayer(visited, gr, node, nextLayer);
42             } else {
43                 gr.link(node, it->second);
44             }
45         }
46     }
47 }
48 }  // namespace
49
50 void translateNetworkToAde(ade::Graph& gr, ICNNNetwork& network) {
51     TGraph tgr(gr);
52     VisitedLayersMap visited;
53     for (auto& data : getRootDataObjects(network)) {
54         assert(nullptr != data);
55         for (auto& layerIt : data->getInputTo()) {
56             auto layer = layerIt.second;
57             assert(nullptr != layer);
58             if (!util::contains(visited, layer)) {
59                 translateVisitLayer(visited, tgr, nullptr, layer);
60             }
61         }
62     }
63 }
64
65 const char* CNNLayerMetadata::name() {
66     return "CNNLayerMetadata";
67 }
68
69 }  // namespace InferenceEngine