1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
6 * @brief A header file for the CNNNetworkIterator class
7 * @file ie_cnn_network_iterator.hpp
11 #include <unordered_set>
15 #include "ie_locked_memory.hpp"
16 #include "ie_icnn_network.hpp"
18 namespace InferenceEngine {
22 * @brief This class enables range loops for CNNNetwork objects
24 class CNNNetworkIterator {
25 std::unordered_set<CNNLayer*> visited;
26 std::list<CNNLayerPtr> nextLayersTovisit;
27 InferenceEngine::CNNLayerPtr currentLayer;
28 ICNNNetwork * network = nullptr;
32 * iterator trait definitions
34 typedef std::forward_iterator_tag iterator_category;
35 typedef CNNLayerPtr value_type;
36 typedef int difference_type;
37 typedef CNNLayerPtr pointer;
38 typedef CNNLayerPtr reference;
41 * @brief Default constructor
43 CNNNetworkIterator() = default;
45 * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
46 * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of scope.
48 explicit CNNNetworkIterator(ICNNNetwork * network) {
50 network->getInputsInfo(inputs);
51 if (!inputs.empty()) {
52 auto & nextLayers = inputs.begin()->second->getInputData()->getInputTo();
53 if (!nextLayers.empty()) {
54 currentLayer = nextLayers.begin()->second;
55 nextLayersTovisit.push_back(currentLayer);
56 visited.insert(currentLayer.get());
62 * @brief Performs pre-increment
63 * @return This CNNNetworkIterator instance
65 CNNNetworkIterator &operator++() {
66 currentLayer = next();
71 * @brief Performs post-increment.
72 * Implementation does not follow the std interface since only move semantics is used
74 void operator++(int) {
75 currentLayer = next();
79 * @brief Checks if the given iterator is not equal to this one
80 * @param that Iterator to compare with
81 * @return true if the given iterator is not equal to this one, false - otherwise
83 bool operator!=(const CNNNetworkIterator &that) const {
84 return !operator==(that);
88 * @brief Gets const layer pointer referenced by this iterator
90 const CNNLayerPtr &operator*() const {
91 if (nullptr == currentLayer) {
92 THROW_IE_EXCEPTION << "iterator out of bound";
98 * @brief Gets a layer pointer referenced by this iterator
100 CNNLayerPtr &operator*() {
101 if (nullptr == currentLayer) {
102 THROW_IE_EXCEPTION << "iterator out of bound";
107 * @brief Compares the given iterator with this one
108 * @param that Iterator to compare with
109 * @return true if the given iterator is equal to this one, false - otherwise
111 bool operator==(const CNNNetworkIterator &that) const {
112 return network == that.network && currentLayer == that.currentLayer;
117 * @brief implementation based on BFS
120 if (nextLayersTovisit.empty()) {
124 auto nextLayer = nextLayersTovisit.front();
125 nextLayersTovisit.pop_front();
127 // visit child that not visited
128 for (auto && output : nextLayer->outData) {
129 for (auto && child : output->getInputTo()) {
130 if (visited.find(child.second.get()) == visited.end()) {
131 nextLayersTovisit.push_back(child.second);
132 visited.insert(child.second.get());
138 for (auto && parent : nextLayer->insData) {
139 auto parentLayer = parent.lock()->getCreatorLayer().lock();
140 if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
141 nextLayersTovisit.push_back(parentLayer);
142 visited.insert(parentLayer.get());
146 return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
149 } // namespace details
150 } // namespace InferenceEngine