From 42c1600cf5fa6659a941d51e33dc6a73e7a2a63e Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Tue, 27 Nov 2018 15:28:46 +0900 Subject: [PATCH] [neurun] Fix Dfs Iterator (#3721) As mentioned from #3617 by @d-poshshoev, Visit checking assumed ID range as vector which is wrong. This commit changes variable `visit` to be a map so it can correctly check nodes for sure. Signed-off-by: Hanjoung Lee --- runtimes/neurun/src/graph/Graph.cc | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/runtimes/neurun/src/graph/Graph.cc b/runtimes/neurun/src/graph/Graph.cc index aa26e92..33aa25c 100644 --- a/runtimes/neurun/src/graph/Graph.cc +++ b/runtimes/neurun/src/graph/Graph.cc @@ -281,13 +281,15 @@ void Graph::PostDfsIterator::iterate(GraphRef graph, const IterFn &fn) { assert(!graph.isBuildingPhase()); // Restrict iteration condition - std::vector visited(graph.operations().size(), false); + std::unordered_map visited; + graph.operations().iterate( + [&](const operation::Index &index, NodeRef) { visited[index] = false; }); std::function dfs_recursive = [&](const operation::Index &index, NodeRef node) -> void { - if (visited[index.asInt()]) + if (visited[index]) return; - visited[index.asInt()] = true; + visited[index] = true; for (auto output : node.getOutputs()) { @@ -304,7 +306,8 @@ void Graph::PostDfsIterator::iterate(GraphRef graph, const IterFn &fn) graph.operations().iterate(dfs_recursive); // All of the operations(nodes) must have been visited. - assert(std::all_of(visited.begin(), visited.end(), [](bool v) { return v; })); + assert(std::all_of(visited.begin(), visited.end(), + [](const std::pair &v) { return v.second; })); } } // namespace graph -- 2.7.4