Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / ade / ade / source / check_cycles.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
#include "passes/check_cycles.hpp"

#include <unordered_map>
#include <vector>

#include "util/assert.hpp"
#include "util/map_range.hpp"

#include "graph.hpp"
#include "node.hpp"
15
16 namespace ade
17 {
18 namespace passes
19 {
20 enum class TraverseState
21 {
22     visiting,
23     visited,
24 };
25
26 using state_t = std::unordered_map<Node*, TraverseState>;
27
28 static void visit(state_t& state, const NodeHandle& node)
29 {
30     ASSERT(nullptr != node);
31     state[node.get()] = TraverseState::visiting;
32     for (auto adj:
33          util::map(node->outEdges(), [](const EdgeHandle& e) { return e->dstNode(); }))
34     {
35         auto it = state.find(adj.get());
36         if (state.end() == it) // not visited
37         {
38             visit(state, adj);
39         }
40         else if (TraverseState::visiting == it->second)
41         {
42             throw_error(CycleFound());
43         }
44     }
45     state[node.get()] = TraverseState::visited;
46
47 }
48
49 void CheckCycles::operator()(const PassContext& context) const
50 {
51     state_t state;
52     for (auto node: context.graph.nodes())
53     {
54         if (state.end() == state.find(node.get()))
55         {
56             // not yet visited during recursion
57             visit(state, node);
58         }
59     }
60 }
61
62 std::string CheckCycles::name()
63 {
64     return "CheckCycles";
65 }
66
67 const char* CycleFound::what() const noexcept
68 {
69     return "Cycle was detected in graph";
70 }
71
72 }
73 }