Add Graph class (#297)
authorVladimir Plazun/AI Tools Lab/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Mon, 4 Jun 2018 13:53:09 +0000 (17:53 +0400)
committerSergey Vostokov/AI Tools Lab/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Mon, 4 Jun 2018 13:53:09 +0000 (16:53 +0300)
Add Graph class

This class is used to represent computation graph

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/core/include/nnc/core/IR/model/graph/graph.h [new file with mode: 0644]
contrib/nnc/libs/core/include/nnc/core/IR/model/visitor/visitor.h
contrib/nnc/libs/core/src/core/IR/model/graph/graph.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/model/graph/graph.h b/contrib/nnc/libs/core/include/nnc/core/IR/model/graph/graph.h
new file mode 100644 (file)
index 0000000..71e7dc0
--- /dev/null
@@ -0,0 +1,73 @@
+#ifndef _NNC_CORE_IR_MODEL_GRAPH_H_
+#define _NNC_CORE_IR_MODEL_GRAPH_H_
+
+#include <string>
+#include <vector>
+#include <type_traits>
+#include <unordered_map>
+
+#include "nnc/core/IR/model/operations/operation.h"
+#include "nnc/core/IR/model/operations/variable_op.h"
+#include "nnc/core/IR/model/graph/ir_node.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace IR
+{
+namespace model {
+
+using ADT::INode;
+class Visitor;
+
+class Graph {
+ public:
+  explicit Graph() = default;
+
+  virtual ~Graph();
+
+  template<typename T, typename ...Args>
+  //make this method callable only with OpDescription subclasses
+  typename std::enable_if<std::is_convertible<T*, OpDescription*>::value, INode::Ref>::type
+  create(const std::string &name, Args &&...args) {
+    auto node = Node<T>::createNode(name, _nodes.size(), std::forward<Args>(args)...);
+    registerNode(node);
+    return node;
+  }
+
+  void accept(Visitor *visitor);
+
+  void markOutput(INode::Ref node);
+  const INode::Ref getInput(const std::string &name);
+  const INode::Ref getOutput(const std::string &name);
+
+ private:
+  void registerNode(INode::Ref node) {
+    _nodes.push_back(node);
+  }
+
+  //TODO: maybe make user to mark input _nodes in a more obvious way
+  void registerNode(Node<ops::VariableOp> *node) {
+    auto it = _inputs.find(node->getName());
+    if( it != _inputs.end()) {
+      throw std::runtime_error("Input name collision");
+    }
+    _inputs.insert(it, {node->getName(), node});
+    _nodes.push_back(node);
+  }
+  
+  std::vector<INode::Ref> _nodes;
+  std::unordered_map<std::string, INode::Ref> _inputs;
+  std::unordered_map<std::string, INode::Ref> _outputs;
+};
+
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_IR_MODEL_GRAPH_H_
index 096478e..1bb8185 100644 (file)
@@ -7,14 +7,17 @@ namespace core {
 namespace IR {
 namespace model {
 
-class INode;
+//Forward declare INode due to circular dependecies with INode::accept(Visitor*);
+namespace ADT {
+  class INode;
+}
 
 class Visitor {
  public:
   // Operation-specific methods like:
   // virtual void visit(INode* node, OpDescriptionSubclass& op) = 0;
 
-  virtual void visit(INode *node) = 0;
+  virtual void visit(ADT::INode *node) = 0;
   virtual ~Visitor() = default;
 };
 
diff --git a/contrib/nnc/libs/core/src/core/IR/model/graph/graph.cpp b/contrib/nnc/libs/core/src/core/IR/model/graph/graph.cpp
new file mode 100644 (file)
index 0000000..6be333a
--- /dev/null
@@ -0,0 +1,77 @@
+#include <deque>
+#include <set>
+
+#include "nnc/core/IR/model/graph/graph.h"
+#include "nnc/core/IR/model/graph/ir_node.h"
+#include "nnc/core/IR/model/operations/operation.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace IR
+{
+namespace model {
+
+const INode::Ref Graph::getInput(const std::string &name) {
+  return _inputs.find(name)->second;
+}
+
+const INode::Ref Graph::getOutput(const std::string &name) {
+  return _outputs.find(name)->second;
+}
+
+void Graph::accept(Visitor *visitor) {
+  std::deque<INode::Ref> q;
+  std::set<INode::Ref> known_nodes;
+
+  for (const auto &e : _inputs) {
+    q.push_back(e.second);
+    known_nodes.insert(e.second); //Consider all input _nodes resolved by default
+  }
+
+  //BFS
+  while (!q.empty()) {
+    auto n = q.front();
+    q.pop_front();
+    n->accept(visitor);
+    for (auto out : n->getNextNodes()) {
+      if (known_nodes.count(out) == 0) {
+
+        bool allInputsResolved = true;
+        for (auto in : out->getPrevNodes()) {
+          if (known_nodes.count(in.node) == 0) {
+            allInputsResolved = false;
+          }
+        }
+        if (allInputsResolved) {
+          known_nodes.insert(out);
+          q.push_back(out);
+        }
+      }
+    }
+  }
+}
+
+Graph::~Graph() {
+  for (auto &node : _nodes) {
+    delete node;
+  }
+}
+
+void Graph::markOutput(INode::Ref node) {
+  auto it = _outputs.find(node->getName());
+  if (it != _outputs.end()) {
+    throw std::runtime_error("Output node with same name already exists");
+  }
+
+  _outputs[node->getName()] = node;
+}
+
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc