From c906cd92810e147b033c870543bf88f849a7e132 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Thu, 16 Aug 2018 17:10:42 +0900 Subject: [PATCH] [neurun] Implement DAG Checker (#2312) A checker that tests if given graph is a directed-acyclic-graph. If there is a cycle the checker fails, otherwise passes. Signed-off-by: Hanjoung Lee --- runtimes/neurun/src/graph/Graph.h | 1 + runtimes/neurun/src/graph/verifier/IVerifier.cc | 42 +++++++++++++++++++-- runtimes/neurun/test/graph/verifier/Verifier.cc | 49 +++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 runtimes/neurun/test/graph/verifier/Verifier.cc diff --git a/runtimes/neurun/src/graph/Graph.h b/runtimes/neurun/src/graph/Graph.h index 68e5523..d2e72ba 100644 --- a/runtimes/neurun/src/graph/Graph.h +++ b/runtimes/neurun/src/graph/Graph.h @@ -43,6 +43,7 @@ public: const operand::IndexSet &outputs() const { return _outputs; } const operand::Set &operands() const { return _operands; } operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor + const operation::Set &operations() const { return _operations; } public: // TODO Introduce Iterator class to support many kinds of interation diff --git a/runtimes/neurun/src/graph/verifier/IVerifier.cc b/runtimes/neurun/src/graph/verifier/IVerifier.cc index 05bdbe5..6493d49 100644 --- a/runtimes/neurun/src/graph/verifier/IVerifier.cc +++ b/runtimes/neurun/src/graph/verifier/IVerifier.cc @@ -9,10 +9,46 @@ namespace graph namespace verifier { -bool DAGChecker::verify(const Graph &) const +bool DAGChecker::verify(const Graph &graph) const { - // TODO Implement DAG check - return true; + auto &operations = graph.operations(); + bool cyclic = false; + std::vector visited(operations.size(), false); + std::vector on_stack(operations.size(), false); + + std::function dfs_recursive = + [&](const operation::Index &index, const operation::Node &node) -> void { + if (on_stack[index.value()]) + cyclic = true; + if (visited[index.value()]) + return; + visited[index.value()] = true; + on_stack[index.value()] = true; + + auto outputs = node.outputs(); + for (auto output : outputs.list()) + { + // TODO Fix traversing algorithm + // Every time need to search for operations that has `outgoing` as incoming from all + // operations but we can hold that info cached + operations.iterate([&](const operation::Index &cand_index, const operation::Node &cand_node) { + auto inputs = cand_node.inputs(); + for (auto input : inputs.list()) + { + if (output == input) + { + dfs_recursive(cand_index, cand_node); + } + } + }); + } + + on_stack[index.value()] = false; + }; + + operations.iterate(dfs_recursive); + + return !cyclic; } } // namespace verifier diff --git a/runtimes/neurun/test/graph/verifier/Verifier.cc b/runtimes/neurun/test/graph/verifier/Verifier.cc new file mode 100644 index 0000000..2030668 --- /dev/null +++ b/runtimes/neurun/test/graph/verifier/Verifier.cc @@ -0,0 +1,49 @@ +#include + +#include "graph/operation/Node.h" +#include "graph/Graph.h" +#include "graph/verifier/IVerifier.h" +#include "nnfw/std/memory.h" + +class MockNode : public neurun::graph::operation::Node +{ +public: + MockNode(neurun::graph::operand::Index input, neurun::graph::operand::Index output) + : _input{input}, _output{output} + { + // DO NOTHING + } + +public: + virtual neurun::graph::operand::IndexSet inputs() const override { return {_input}; } + virtual neurun::graph::operand::IndexSet outputs() const override { return {_output}; } + virtual const ::internal::tflite::op::Node *op() const override { return nullptr; } + +private: + neurun::graph::operand::Index _input; + neurun::graph::operand::Index _output; +}; + +TEST(Verifier, dag_checker) +{ + neurun::graph::Graph graph; + neurun::graph::verifier::DAGChecker verifier; + + internal::tflite::operand::Shape shape{1u}; + shape.dim(0) = 3; + + auto operand1 = graph.addOperand(shape); + auto operand2 = graph.addOperand(shape); + + graph.addInput(operand1); + graph.addOutput(operand2); + + graph.addOperation(nnfw::make_unique(operand1, operand2)); + + ASSERT_EQ(verifier.verify(graph), true); + + // Create cycle + graph.addOperation(nnfw::make_unique(operand2, operand1)); + + ASSERT_EQ(verifier.verify(graph), false); +} -- 2.7.4