Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / include / details / ie_cnn_network_iterator.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 /**
6  * @brief A header file for the CNNNetworkIterator class
7  * @file ie_cnn_network_iterator.hpp
8  */
9 #pragma once
10 #include <utility>
11 #include <unordered_set>
12 #include <list>
13 #include <iterator>
14
15 #include "ie_locked_memory.hpp"
16 #include "ie_icnn_network.hpp"
17
18 namespace InferenceEngine {
19 namespace details {
20
21 /**
22  * @brief This class enables range loops for CNNNetwork objects
23  */
24 class CNNNetworkIterator {
25     std::unordered_set<CNNLayer*> visited;
26     std::list<CNNLayerPtr> nextLayersTovisit;
27     InferenceEngine::CNNLayerPtr currentLayer;
28     ICNNNetwork * network = nullptr;
29
30  public:
31     /**
32      * iterator trait definitions
33      */
34     typedef std::forward_iterator_tag iterator_category;
35     typedef CNNLayerPtr value_type;
36     typedef int         difference_type;
37     typedef CNNLayerPtr pointer;
38     typedef CNNLayerPtr reference;
39
40     /**
41      * @brief Default constructor
42      */
43     CNNNetworkIterator() = default;
44     /**
45      * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
46      * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of scope.
47      */
48     explicit CNNNetworkIterator(ICNNNetwork * network) {
49         InputsDataMap inputs;
50         network->getInputsInfo(inputs);
51         if (!inputs.empty()) {
52             auto & nextLayers = inputs.begin()->second->getInputData()->getInputTo();
53             if (!nextLayers.empty()) {
54                 currentLayer = nextLayers.begin()->second;
55                 nextLayersTovisit.push_back(currentLayer);
56                 visited.insert(currentLayer.get());
57             }
58         }
59     }
60
61     /**
62      * @brief Performs pre-increment 
63      * @return This CNNNetworkIterator instance
64      */
65     CNNNetworkIterator &operator++() {
66         currentLayer = next();
67         return *this;
68     }
69
70     /**
71      * @brief Performs post-increment.
72      * Implementation does not follow the std interface since only move semantics is used
73      */
74     void operator++(int) {
75         currentLayer = next();
76     }
77
78     /**
79      * @brief Checks if the given iterator is not equal to this one
80      * @param that Iterator to compare with
81      * @return true if the given iterator is not equal to this one, false - otherwise
82      */
83     bool operator!=(const CNNNetworkIterator &that) const {
84         return !operator==(that);
85     }
86
87     /**
88      * @brief Gets const layer pointer referenced by this iterator
89      */
90     const CNNLayerPtr &operator*() const {
91         if (nullptr == currentLayer) {
92             THROW_IE_EXCEPTION << "iterator out of bound";
93         }
94         return currentLayer;
95     }
96
97     /**
98      * @brief Gets a layer pointer referenced by this iterator
99      */
100     CNNLayerPtr &operator*() {
101         if (nullptr == currentLayer) {
102             THROW_IE_EXCEPTION << "iterator out of bound";
103         }
104         return currentLayer;
105     }
106     /**
107      * @brief Compares the given iterator with this one
108      * @param that Iterator to compare with
109      * @return true if the given iterator is equal to this one, false - otherwise
110      */
111     bool operator==(const CNNNetworkIterator &that) const {
112         return network == that.network && currentLayer == that.currentLayer;
113     }
114
115  private:
116     /**
117      * @brief implementation based on BFS
118      */
119     CNNLayerPtr next() {
120         if (nextLayersTovisit.empty()) {
121             return nullptr;
122         }
123
124         auto nextLayer = nextLayersTovisit.front();
125         nextLayersTovisit.pop_front();
126
127         // visit child that not visited
128         for (auto && output : nextLayer->outData) {
129             for (auto && child : output->getInputTo()) {
130                 if (visited.find(child.second.get()) == visited.end()) {
131                     nextLayersTovisit.push_back(child.second);
132                     visited.insert(child.second.get());
133                 }
134             }
135         }
136
137         // visit parents
138         for (auto && parent  : nextLayer->insData) {
139             auto parentLayer = parent.lock()->getCreatorLayer().lock();
140             if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
141                 nextLayersTovisit.push_back(parentLayer);
142                 visited.insert(parentLayer.get());
143             }
144         }
145
146         return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
147     }
148 };
149 }  // namespace details
150 }  // namespace InferenceEngine