[neurun] Introduce NodeVisitor for Graph operation (#2405)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Wed, 22 Aug 2018 08:26:24 +0000 (17:26 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 22 Aug 2018 08:26:24 +0000 (17:26 +0900)
Introduce a NodeVisitor `neurun::graph::operation::Node` which is a
visitor interface for our graph nodes(operations). Once this class
gets used, NodeVisitor for `internal::tflite::op::Node` will no longer
be used.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
21 files changed:
runtimes/neurun/src/graph/operation/AvgPool2D.cc
runtimes/neurun/src/graph/operation/AvgPool2D.h
runtimes/neurun/src/graph/operation/Concat.cc
runtimes/neurun/src/graph/operation/Concat.h
runtimes/neurun/src/graph/operation/Conv2D.cc
runtimes/neurun/src/graph/operation/Conv2D.h
runtimes/neurun/src/graph/operation/FullyConnected.cc
runtimes/neurun/src/graph/operation/FullyConnected.h
runtimes/neurun/src/graph/operation/MaxPool2D.cc
runtimes/neurun/src/graph/operation/MaxPool2D.h
runtimes/neurun/src/graph/operation/NOP.cc
runtimes/neurun/src/graph/operation/NOP.h
runtimes/neurun/src/graph/operation/Node.cc [new file with mode: 0644]
runtimes/neurun/src/graph/operation/Node.h
runtimes/neurun/src/graph/operation/NodeVisitor.h [new file with mode: 0644]
runtimes/neurun/src/graph/operation/Reshape.cc
runtimes/neurun/src/graph/operation/Reshape.h
runtimes/neurun/src/graph/operation/Softmax.cc
runtimes/neurun/src/graph/operation/Softmax.h
runtimes/neurun/test/graph/operation/Set.cc
runtimes/neurun/test/graph/verifier/Verifier.cc

index 83af4de..2aa48b2 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -13,6 +15,8 @@ namespace AvgPool2D
 namespace Implicit
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 void Node::inputs(const operand::IndexSet &indexes)
 {
   assert(indexes.size() == 1);
index 3268c6f..8b4da0b 100644 (file)
@@ -25,6 +25,9 @@ public:
   }
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override { return {_op->param().ifm_index}; }
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
   virtual void inputs(const operand::IndexSet &indexes) override;
index bc27eb2..61e6835 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -11,6 +13,8 @@ namespace operation
 namespace Concat
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 operand::IndexSet Node::inputs() const
 {
   operand::IndexSet set;
index f4f0faf..5166988 100644 (file)
@@ -21,6 +21,9 @@ public:
   Node(std::unique_ptr<::internal::tflite::op::Concat::Node> &&op) : _op{std::move(op)} {}
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override;
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
   virtual void inputs(const operand::IndexSet &indexes) override;
index d4d8b04..cf7381e 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -13,6 +15,8 @@ namespace Conv2D
 namespace Implicit
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 void Node::inputs(const operand::IndexSet &indexes)
 {
   assert(indexes.size() == 1);
index ed97a4b..8021f86 100644 (file)
@@ -23,6 +23,9 @@ public:
   Node(std::unique_ptr<::internal::tflite::op::Conv2D::implicit::Node> &&op) : _op{std::move(op)} {}
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override
   {
     return {_op->param().ifm_index, _op->param().ker_index, _op->param().bias_index};
index c0e4527..cdb9de3 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -11,6 +13,8 @@ namespace operation
 namespace FullyConnected
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 void Node::inputs(const operand::IndexSet &indexes)
 {
   assert(indexes.size() == 1);
index ca87e30..396dfe0 100644 (file)
@@ -21,6 +21,9 @@ public:
   Node(std::unique_ptr<::internal::tflite::op::FullyConnected::Node> &&op) : _op{std::move(op)} {}
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override
   {
     return {_op->param().input_index, _op->param().weight_index, _op->param().bias_index};
index 5f12ac1..eb79310 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -13,6 +15,8 @@ namespace MaxPool2D
 namespace Implicit
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 void Node::inputs(const operand::IndexSet &indexes)
 {
   assert(indexes.size() == 1);
index 0630182..65f829f 100644 (file)
@@ -25,6 +25,9 @@ public:
   }
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override { return {_op->param().ifm_index}; }
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
   virtual void inputs(const operand::IndexSet &indexes) override;
index 2654057..07fcf29 100644 (file)
@@ -1,5 +1,7 @@
 #include "NOP.h"
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -9,6 +11,8 @@ namespace operation
 namespace NOP
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 operand::IndexSet Node::inputs() const
 {
   operand::IndexSet set;
index 76e96de..25490fe 100644 (file)
@@ -21,6 +21,9 @@ public:
   Node(std::unique_ptr<::internal::tflite::op::NOP::Node> &&op) : _op{std::move(op)} {}
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override;
   virtual operand::IndexSet outputs() const override;
   virtual void inputs(const operand::IndexSet &indexes) override;
diff --git a/runtimes/neurun/src/graph/operation/Node.cc b/runtimes/neurun/src/graph/operation/Node.cc
new file mode 100644 (file)
index 0000000..ae02da2
--- /dev/null
@@ -0,0 +1,14 @@
+#include "Node.h"
+
+namespace neurun
+{
+namespace graph
+{
+namespace operation
+{
+
+// NO IMPLEMENTATION YET
+
+} // namespace operation
+} // namespace graph
+} // namespace neurun
index 2da210b..27d7abb 100644 (file)
@@ -15,6 +15,7 @@ namespace operation
 {
 
 class LowerInfo;
+struct NodeVisitor;
 
 class Node
 {
@@ -22,6 +23,9 @@ public:
   virtual ~Node() = default;
 
 public:
+  virtual void accept(NodeVisitor &) const = 0;
+
+public:
   virtual operand::IndexSet inputs() const = 0;
   virtual operand::IndexSet outputs() const = 0;
   // It's for only input/output tensors but const data.
diff --git a/runtimes/neurun/src/graph/operation/NodeVisitor.h b/runtimes/neurun/src/graph/operation/NodeVisitor.h
new file mode 100644 (file)
index 0000000..3c7b67d
--- /dev/null
@@ -0,0 +1,38 @@
+#ifndef __NEURUN_GRAPH_OPERATION_NODE_VISITOR_H__
+#define __NEURUN_GRAPH_OPERATION_NODE_VISITOR_H__
+
+#include "Conv2D.h"
+#include "MaxPool2D.h"
+#include "AvgPool2D.h"
+#include "Concat.h"
+#include "Reshape.h"
+#include "FullyConnected.h"
+#include "Softmax.h"
+#include "NOP.h"
+
+namespace neurun
+{
+namespace graph
+{
+namespace operation
+{
+
+struct NodeVisitor
+{
+  virtual ~NodeVisitor() = default;
+
+  virtual void visit(const Conv2D::Implicit::Node &) = 0;
+  virtual void visit(const MaxPool2D::Implicit::Node &) = 0;
+  virtual void visit(const AvgPool2D::Implicit::Node &) = 0;
+  virtual void visit(const Concat::Node &) = 0;
+  virtual void visit(const Reshape::Node &) = 0;
+  virtual void visit(const FullyConnected::Node &) = 0;
+  virtual void visit(const Softmax::Node &) = 0;
+  virtual void visit(const NOP::Node &) = 0;
+};
+
+} // namespace operation
+} // namespace graph
+} // namespace neurun
+
+#endif // __NEURUN_GRAPH_OPERATION_NODE_VISITOR_H__
index 68bd989..a72d81c 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -11,6 +13,8 @@ namespace operation
 namespace Reshape
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 void Node::inputs(const operand::IndexSet &indexes)
 {
   assert(indexes.size() == 1);
index 15b408e..11b712e 100644 (file)
@@ -21,6 +21,9 @@ public:
   Node(std::unique_ptr<::internal::tflite::op::Reshape::Node> &&op) : _op{std::move(op)} {}
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override
   {
     return {_op->param().input_index, _op->param().shape_index};
index 8c51aae..f3ca861 100644 (file)
@@ -2,6 +2,8 @@
 
 #include <cassert>
 
+#include "NodeVisitor.h"
+
 namespace neurun
 {
 namespace graph
@@ -11,6 +13,8 @@ namespace operation
 namespace Softmax
 {
 
+void Node::accept(NodeVisitor &v) const { v.visit(*this); }
+
 void Node::inputs(const operand::IndexSet &indexes)
 {
   assert(indexes.size() == 1);
index 17bf8c9..06583b1 100644 (file)
@@ -21,6 +21,9 @@ public:
   Node(std::unique_ptr<::internal::tflite::op::Softmax::Node> &&op) : _op{std::move(op)} {}
 
 public:
+  virtual void accept(NodeVisitor &) const override;
+
+public:
   virtual operand::IndexSet inputs() const override { return {_op->param().input_index}; }
   virtual operand::IndexSet outputs() const override { return {_op->param().output_index}; }
   virtual void inputs(const operand::IndexSet &indexes) override;
index 0e3413b..e63a52c 100644 (file)
@@ -12,6 +12,9 @@ public:
   TestNode() = default;
 
 public:
+  virtual void accept(neurun::graph::operation::NodeVisitor &) const override {}
+
+public:
   virtual neurun::graph::operand::IndexSet inputs() const { return {1, 2, 3, 4}; }
   virtual neurun::graph::operand::IndexSet outputs() const { return {1, 2, 3}; }
   virtual void inputs(const neurun::graph::operand::IndexSet &indexes) override {}
index fc39488..db2fd94 100644 (file)
@@ -15,6 +15,9 @@ public:
   }
 
 public:
+  virtual void accept(neurun::graph::operation::NodeVisitor &) const override {}
+
+public:
   virtual neurun::graph::operand::IndexSet inputs() const override { return {_input}; }
   virtual neurun::graph::operand::IndexSet outputs() const override { return {_output}; }
   virtual void inputs(const neurun::graph::operand::IndexSet &indexes) override {}