// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
#include "hetero_ade_util.hpp"

#include <cassert>
#include <unordered_map>

#include <ie_icnn_network.hpp>
#include <ie_util_internal.hpp>
#include <ie_layers.h>

#include <ade/util/algorithm.hpp>
#include <ade/graph.hpp>
#include <ade/typed_graph.hpp>
18 namespace InferenceEngine {
20 using VisitedLayersMap = std::unordered_map<CNNLayer::Ptr, ade::NodeHandle>;
21 using TGraph = ade::TypedGraph<CNNLayerMetadata>;
23 void translateVisitLayer(VisitedLayersMap& visited,
25 const ade::NodeHandle& prevNode,
26 const CNNLayer::Ptr& layer) {
27 assert(nullptr != layer);;
28 assert(!ade::util::contains(visited, layer));
29 auto node = gr.createNode();
30 gr.metadata(node).set(CNNLayerMetadata{layer});
31 if (nullptr != prevNode) {
32 gr.link(prevNode, node);
34 visited.insert({layer, node});
35 for (auto&& data : layer->outData) {
36 for (auto&& layerIt : data->getInputTo()) {
37 auto nextLayer = layerIt.second;
38 auto it = visited.find(nextLayer);
39 if (visited.end() == it) {
40 translateVisitLayer(visited, gr, node, nextLayer);
42 gr.link(node, it->second);
49 void translateNetworkToAde(ade::Graph& gr, ICNNNetwork& network) {
51 VisitedLayersMap visited;
52 for (auto& data : getRootDataObjects(network)) {
53 assert(nullptr != data);
54 for (auto& layerIt : data->getInputTo()) {
55 auto layer = layerIt.second;
56 assert(nullptr != layer);
57 if (!ade::util::contains(visited, layer)) {
58 translateVisitLayer(visited, tgr, nullptr, layer);
64 const char* CNNLayerMetadata::name() {
65 return "CNNLayerMetadata";
68 } // namespace InferenceEngine