[neurun] Implement DAG Checker (#2312)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Thu, 16 Aug 2018 08:10:42 +0000 (17:10 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 16 Aug 2018 08:10:42 +0000 (17:10 +0900)
A checker that tests if given graph is a directed-acyclic-graph.
If there is a cycle the checker fails, otherwise passes.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/src/graph/Graph.h
runtimes/neurun/src/graph/verifier/IVerifier.cc
runtimes/neurun/test/graph/verifier/Verifier.cc [new file with mode: 0644]

index 68e5523..d2e72ba 100644 (file)
@@ -43,6 +43,7 @@ public:
   const operand::IndexSet &outputs() const { return _outputs; }
   const operand::Set &operands() const { return _operands; }
   operand::Set &operands() { return _operands; } // TODO Remove this non-const accessor
+  const operation::Set &operations() const { return _operations; }
 
 public:
   // TODO Introduce Iterator class to support many kinds of interation
index 05bdbe5..6493d49 100644 (file)
@@ -9,10 +9,46 @@ namespace graph
 namespace verifier
 {
 
-bool DAGChecker::verify(const Graph &) const
+bool DAGChecker::verify(const Graph &graph) const
 {
-  // TODO Implement DAG check
-  return true;
+  auto &operations = graph.operations();
+  bool cyclic = false;
+  std::vector<bool> visited(operations.size(), false);
+  std::vector<bool> on_stack(operations.size(), false);
+
+  std::function<void(const operation::Index &index, const operation::Node &)> dfs_recursive =
+      [&](const operation::Index &index, const operation::Node &node) -> void {
+    if (on_stack[index.value()])
+      cyclic = true;
+    if (visited[index.value()])
+      return;
+    visited[index.value()] = true;
+    on_stack[index.value()] = true;
+
+    auto outputs = node.outputs();
+    for (auto output : outputs.list())
+    {
+      // TODO Fix traversing algorithm
+      //      Every time need to search for operations that has `outgoing` as incoming from all
+      //      operations but we can hold that info cached
+      operations.iterate([&](const operation::Index &cand_index, const operation::Node &cand_node) {
+        auto inputs = cand_node.inputs();
+        for (auto input : inputs.list())
+        {
+          if (output == input)
+          {
+            dfs_recursive(cand_index, cand_node);
+          }
+        }
+      });
+    }
+
+    on_stack[index.value()] = false;
+  };
+
+  operations.iterate(dfs_recursive);
+
+  return !cyclic;
 }
 
 } // namespace verifier
diff --git a/runtimes/neurun/test/graph/verifier/Verifier.cc b/runtimes/neurun/test/graph/verifier/Verifier.cc
new file mode 100644 (file)
index 0000000..2030668
--- /dev/null
@@ -0,0 +1,49 @@
+#include <gtest/gtest.h>
+
+#include "graph/operation/Node.h"
+#include "graph/Graph.h"
+#include "graph/verifier/IVerifier.h"
+#include "nnfw/std/memory.h"
+
+class MockNode : public neurun::graph::operation::Node
+{
+public:
+  MockNode(neurun::graph::operand::Index input, neurun::graph::operand::Index output)
+      : _input{input}, _output{output}
+  {
+    // DO NOTHING
+  }
+
+public:
+  virtual neurun::graph::operand::IndexSet inputs() const override { return {_input}; }
+  virtual neurun::graph::operand::IndexSet outputs() const override { return {_output}; }
+  virtual const ::internal::tflite::op::Node *op() const override { return nullptr; }
+
+private:
+  neurun::graph::operand::Index _input;
+  neurun::graph::operand::Index _output;
+};
+
+TEST(Verifier, dag_checker)
+{
+  neurun::graph::Graph graph;
+  neurun::graph::verifier::DAGChecker verifier;
+
+  internal::tflite::operand::Shape shape{1u};
+  shape.dim(0) = 3;
+
+  auto operand1 = graph.addOperand(shape);
+  auto operand2 = graph.addOperand(shape);
+
+  graph.addInput(operand1);
+  graph.addOutput(operand2);
+
+  graph.addOperation(nnfw::make_unique<MockNode>(operand1, operand2));
+
+  ASSERT_EQ(verifier.verify(graph), true);
+
+  // Create cycle
+  graph.addOperation(nnfw::make_unique<MockNode>(operand2, operand1));
+
+  ASSERT_EQ(verifier.verify(graph), false);
+}