From f0150bed4c5456d112359f1b3fe5e4d50ffcf811 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Tue, 4 Sep 2018 13:07:38 +0900 Subject: [PATCH] [neurun] Enhance DFS implementation (#2561) Now that we have operand def/use info, we can do DFS efficiently. This is applied to PostDfsIterator. This might also be applied to DAGVerifier, but operand def/use info is built when Graph building is finished. Signed-off-by: Hanjoung Lee --- runtimes/neurun/src/graph/Graph.cc | 23 ++++++++--------------- runtimes/neurun/src/graph/Graph.h | 2 +- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/runtimes/neurun/src/graph/Graph.cc b/runtimes/neurun/src/graph/Graph.cc index 04e08ff..ef97c84 100644 --- a/runtimes/neurun/src/graph/Graph.cc +++ b/runtimes/neurun/src/graph/Graph.cc @@ -188,6 +188,8 @@ void Graph::DefaultIterator::iterate(GraphRef graph, const IterFn &fn) template void Graph::PostDfsIterator::iterate(GraphRef graph, const IterFn &fn) const { + assert(!graph.isBuildingPhase()); // Restrict iteration condition + std::vector visited(graph._operations.size(), false); std::function dfs_recursive = @@ -196,22 +198,13 @@ void Graph::PostDfsIterator::iterate(GraphRef graph, const IterFn &fn) return; visited[index.asInt()] = true; - auto outputs = node.getOutputs(); - for (auto output : outputs) + for (auto output : node.getOutputs()) { - // TODO Fix traversing algorithm - // Every time need to search for operations that has `outgoing` as incoming from all - // operations but we can hold that info cached - graph._operations.iterate([&](const operation::Index &cand_index, NodeRef cand_node) -> void { - auto inputs = cand_node.getInputs(); - for (auto input : inputs) - { - if (output == input) - { - dfs_recursive(cand_index, cand_node); - } - } - }); + const auto &operand = graph._operands.at(output); + for (const auto &use : operand.getUses().list()) + { + dfs_recursive(use, graph._operations.at(use)); + } } fn(node); diff --git a/runtimes/neurun/src/graph/Graph.h b/runtimes/neurun/src/graph/Graph.h index 1e217cf..c9d1e38 100644 --- a/runtimes/neurun/src/graph/Graph.h +++ b/runtimes/neurun/src/graph/Graph.h @@ -86,7 +86,7 @@ public: void finishBuilding(void); void lower(void); std::unique_ptr linearize(void); - bool isBuildingPhase(void) { return _phase == Phase::BUILDING; } + bool isBuildingPhase(void) const { return _phase == Phase::BUILDING; } private: void initializeUseDef(); -- 2.7.4