[neurun] Add setInput/setOutput into graph node (#2406)
author김수진/동작제어Lab(SR)/Engineer/삼성전자 <sjsujin.kim@samsung.com>
Wed, 22 Aug 2018 07:28:13 +0000 (16:28 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 22 Aug 2018 07:28:13 +0000 (16:28 +0900)
* [neurun] Add setInput/setOutput into graph node

This commit adds setInput/setOutput into graph node.

Signed-off-by: sjsujinkim sjsujin.kim@samsung.com
* change function names from set{In/Out}put to {in/out}put

* fix format in Set test

28 files changed:
runtimes/neurun/src/graph/operation/AvgPool2D.cc
runtimes/neurun/src/graph/operation/AvgPool2D.h
runtimes/neurun/src/graph/operation/Concat.cc
runtimes/neurun/src/graph/operation/Concat.h
runtimes/neurun/src/graph/operation/Conv2D.cc
runtimes/neurun/src/graph/operation/Conv2D.h
runtimes/neurun/src/graph/operation/FullyConnected.cc
runtimes/neurun/src/graph/operation/FullyConnected.h
runtimes/neurun/src/graph/operation/MaxPool2D.cc
runtimes/neurun/src/graph/operation/MaxPool2D.h
runtimes/neurun/src/graph/operation/NOP.cc
runtimes/neurun/src/graph/operation/NOP.h
runtimes/neurun/src/graph/operation/Node.h
runtimes/neurun/src/graph/operation/Reshape.cc
runtimes/neurun/src/graph/operation/Reshape.h
runtimes/neurun/src/graph/operation/Softmax.cc
runtimes/neurun/src/graph/operation/Softmax.h
runtimes/neurun/src/internal/op/AvgPool2D.h
runtimes/neurun/src/internal/op/Concat.h
runtimes/neurun/src/internal/op/Conv2D.h
runtimes/neurun/src/internal/op/FullyConnected.h
runtimes/neurun/src/internal/op/MaxPool2D.h
runtimes/neurun/src/internal/op/NOP.h
runtimes/neurun/src/internal/op/Reshape.h
runtimes/neurun/src/internal/op/Softmax.h
runtimes/neurun/test/graph/operation/Set.cc
runtimes/neurun/test/graph/operation/SetIO.cc [new file with mode: 0644]
runtimes/neurun/test/graph/verifier/Verifier.cc

index ed6f190..83af4de 100644 (file)
@@ -1,5 +1,7 @@
 #include "AvgPool2D.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -11,7 +13,21 @@ namespace AvgPool2D
 namespace Implicit
 {
 
-// NO IMPLEMENTATION YET
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ifm_index = indexes.at(index).asInt();
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ofm_index = indexes.at(index).asInt();
+}
 
 } // namespace Implicit
 } // namespace AvgPool2D
index bf76633..3268c6f 100644 (file)
@@ -27,6 +27,8 @@ public:
 public:
   virtual operand::IndexSet inputs() const override { return {_op->param().ifm_index}; }
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index e52424c..bc27eb2 100644 (file)
@@ -1,5 +1,7 @@
 #include "Concat.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -20,6 +22,24 @@ operand::IndexSet Node::inputs() const
   return set;
 }
 
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  std::vector<int32_t> inds;
+  for (auto index : indexes.list())
+  {
+    inds.emplace_back(index.asInt());
+  }
+  _op->param().ifm_indexes = inds;
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ofm_index = indexes.at(index).asInt();
+}
+
 } // namespace Concat
 } // namespace operation
 } // namespace graph
index feda9c8..f4f0faf 100644 (file)
@@ -23,6 +23,8 @@ public:
 public:
   virtual operand::IndexSet inputs() const override;
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index 5f17931..d4d8b04 100644 (file)
@@ -1,5 +1,7 @@
 #include "Conv2D.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -11,7 +13,21 @@ namespace Conv2D
 namespace Implicit
 {
 
-// NO IMPLEMENTATION YET
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ifm_index = indexes.at(index).asInt();
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ofm_index = indexes.at(index).asInt();
+}
 
 } // namespace Implicit
 } // namespace Conv2D
index 22d0d99..ed97a4b 100644 (file)
@@ -28,6 +28,8 @@ public:
     return {_op->param().ifm_index, _op->param().ker_index, _op->param().bias_index};
   }
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index 338d734..c0e4527 100644 (file)
@@ -1,5 +1,7 @@
 #include "FullyConnected.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -9,7 +11,21 @@ namespace operation
 namespace FullyConnected
 {
 
-// NO IMPLEMENTATION YET
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().input_index = indexes.at(index).asInt();
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().output_index = indexes.at(index).asInt();
+}
 
 } // namespace FullyConnected
 } // namespace operation
index 18ac24f..ca87e30 100644 (file)
@@ -26,6 +26,8 @@ public:
     return {_op->param().input_index, _op->param().weight_index, _op->param().bias_index};
   }
   virtual operand::IndexSet outputs() const override { return {_op->param().output_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index 45bc3e6..5f12ac1 100644 (file)
@@ -1,5 +1,7 @@
 #include "MaxPool2D.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -11,7 +13,21 @@ namespace MaxPool2D
 namespace Implicit
 {
 
-// NO IMPLEMENTATION YET
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ifm_index = indexes.at(index).asInt();
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().ofm_index = indexes.at(index).asInt();
+}
 
 } // namespace Implicit
 } // namespace MaxPool2D
index 0b63c7e..0630182 100644 (file)
@@ -27,6 +27,8 @@ public:
 public:
   virtual operand::IndexSet inputs() const override { return {_op->param().ifm_index}; }
   virtual operand::IndexSet outputs() const override { return {_op->param().ofm_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index 66afdaf..2654057 100644 (file)
@@ -31,6 +31,26 @@ operand::IndexSet Node::outputs() const
   return set;
 }
 
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  std::vector<int32_t> inds;
+  for (auto index : indexes.list())
+  {
+    inds.emplace_back(index.asInt());
+  }
+  _op->param().ifm_indexes = inds;
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  std::vector<int32_t> inds;
+  for (auto index : indexes.list())
+  {
+    inds.emplace_back(index.asInt());
+  }
+  _op->param().ofm_indexes = inds;
+}
+
 } // namespace NOP
 } // namespace operation
 } // namespace graph
index 8d038e6..76e96de 100644 (file)
@@ -23,6 +23,8 @@ public:
 public:
   virtual operand::IndexSet inputs() const override;
   virtual operand::IndexSet outputs() const override;
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index a00466b..2da210b 100644 (file)
@@ -24,6 +24,9 @@ public:
 public:
   virtual operand::IndexSet inputs() const = 0;
   virtual operand::IndexSet outputs() const = 0;
+  // It's for only input/output tensors but const data.
+  virtual void inputs(const operand::IndexSet &indexes) = 0;
+  virtual void outputs(const operand::IndexSet &indexes) = 0;
   virtual const ::internal::tflite::op::Node *op() const = 0;
 
 public:
index 09f7d7e..68bd989 100644 (file)
@@ -1,5 +1,7 @@
 #include "Reshape.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -9,7 +11,21 @@ namespace operation
 namespace Reshape
 {
 
-// NO IMPLEMENTATION YET
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().input_index = indexes.at(index).asInt();
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().output_index = indexes.at(index).asInt();
+}
 
 } // namespace Reshape
 } // namespace operation
index facf29d..15b408e 100644 (file)
@@ -26,6 +26,8 @@ public:
     return {_op->param().input_index, _op->param().shape_index};
   }
   virtual operand::IndexSet outputs() const override { return {_op->param().output_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index ffdd210..8c51aae 100644 (file)
@@ -1,5 +1,7 @@
 #include "Softmax.h"
 
+#include <cassert>
+
 namespace neurun
 {
 namespace graph
@@ -9,7 +11,21 @@ namespace operation
 namespace Softmax
 {
 
-// NO IMPLEMENTATION YET
+void Node::inputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().input_index = indexes.at(index).asInt();
+}
+
+void Node::outputs(const operand::IndexSet &indexes)
+{
+  assert(indexes.size() == 1);
+
+  ::neurun::graph::operand::IO::Index index{0};
+  _op->param().output_index = indexes.at(index).asInt();
+}
 
 } // namespace Softmax
 } // namespace operation
index 56beec4..17bf8c9 100644 (file)
@@ -23,6 +23,8 @@ public:
 public:
   virtual operand::IndexSet inputs() const override { return {_op->param().input_index}; }
   virtual operand::IndexSet outputs() const override { return {_op->param().output_index}; }
+  virtual void inputs(const operand::IndexSet &indexes) override;
+  virtual void outputs(const operand::IndexSet &indexes) override;
   virtual const ::internal::tflite::op::Node *op() const override { return _op.get(); }
 
 private:
index 1696878..8c99e56 100644 (file)
@@ -48,12 +48,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace implicit
index b4c2d0e..0b195d1 100644 (file)
@@ -39,12 +39,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace Concat
index 5764472..44184c9 100644 (file)
@@ -47,12 +47,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace implicit
index af0ba4c..a0c446b 100644 (file)
@@ -40,12 +40,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace FullyConnected
index f3f00b3..4470c9a 100644 (file)
@@ -48,12 +48,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace implicit
index 95043ad..4adcfa3 100644 (file)
@@ -38,12 +38,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace NOP
index 7a26506..8600c6b 100644 (file)
@@ -38,12 +38,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace Reshape
index 54434e2..f1b7624 100644 (file)
@@ -38,12 +38,13 @@ public:
 
 public:
   const Param &param(void) const { return _param; }
+  Param &param(void) { return _param; }
 
 public:
   void accept(NodeVisitor &&) const override;
 
 private:
-  const Param _param;
+  Param _param;
 };
 
 } // namespace Softmax
index 759618e..0e3413b 100644 (file)
@@ -14,6 +14,8 @@ public:
 public:
   virtual neurun::graph::operand::IndexSet inputs() const { return {1, 2, 3, 4}; }
   virtual neurun::graph::operand::IndexSet outputs() const { return {1, 2, 3}; }
+  virtual void inputs(const neurun::graph::operand::IndexSet &indexes) override {}
+  virtual void outputs(const neurun::graph::operand::IndexSet &indexes) override {}
   virtual const ::internal::tflite::op::Node *op() const { return nullptr; }
 };
 
diff --git a/runtimes/neurun/test/graph/operation/SetIO.cc b/runtimes/neurun/test/graph/operation/SetIO.cc
new file mode 100644 (file)
index 0000000..bc1e15f
--- /dev/null
@@ -0,0 +1,74 @@
+#include <gtest/gtest.h>
+
+#include "graph/Graph.h"
+#include "nnfw/std/memory.h"
+#include "graph/operation/Conv2D.h"
+#include "graph/operation/Concat.h"
+#include "graph/operand/Index.h"
+#include "graph/operand/IndexSet.h"
+
+#include <stdexcept>
+
+using Index = neurun::graph::operand::IO::Index;
+using IndexSet = neurun::graph::operand::IndexSet;
+
+TEST(graph_operation_setIO, operation_setIO_conv)
+{
+  neurun::graph::Graph graph;
+
+  neurun::internal::operand::Shape shape{1u};
+  shape.dim(0) = 3;
+
+  // Add Conv
+  std::vector<uint32_t> params;
+  for (int i = 0; i < 7; ++i)
+  {
+    params.emplace_back(graph.addOperand(shape).asInt());
+  }
+  uint32_t outoperand = graph.addOperand(shape).asInt();
+
+  using Param = internal::tflite::op::Conv2D::implicit::Param;
+  using Node = internal::tflite::op::Conv2D::implicit::Node;
+  using GraphNode = neurun::graph::operation::Conv2D::Implicit::Node;
+
+  auto conv = nnfw::make_unique<GraphNode>(
+      nnfw::make_unique<Node>(Param(7, params.data(), 1, &outoperand)));
+
+  ASSERT_EQ(conv->inputs().at(Index{0}).asInt(), params[0]);
+  conv->inputs({8});
+  ASSERT_NE(conv->inputs().at(Index{0}).asInt(), params[0]);
+  ASSERT_EQ(conv->inputs().at(Index{0}).asInt(), 8);
+}
+
+TEST(graph_operation_setIO, operation_setIO_concat)
+{
+  neurun::graph::Graph graph;
+
+  neurun::internal::operand::Shape shape{1u};
+  shape.dim(0) = 3;
+
+  // Add Concat
+  std::vector<uint32_t> params;
+  for (int i = 0; i < 7; ++i)
+  {
+    params.emplace_back(graph.addOperand(shape).asInt());
+  }
+  uint32_t outoperand = graph.addOperand(shape).asInt();
+
+  using Param = internal::tflite::op::Concat::Param;
+  using Node = internal::tflite::op::Concat::Node;
+  using GraphNode = neurun::graph::operation::Concat::Node;
+
+  auto concat = nnfw::make_unique<GraphNode>(
+      nnfw::make_unique<Node>(Param(7, params.data(), 1, &outoperand)));
+
+  ASSERT_EQ(concat->inputs().size(), 6);
+  ASSERT_EQ(concat->inputs().at(Index{0}).asInt(), params[0]);
+
+  concat->inputs({80, 6, 9, 11});
+  ASSERT_EQ(concat->inputs().size(), 4);
+  ASSERT_NE(concat->inputs().at(Index{0}).asInt(), params[0]);
+  ASSERT_EQ(concat->inputs().at(Index{0}).asInt(), 80);
+  ASSERT_EQ(concat->inputs().at(Index{2}).asInt(), 9);
+  ASSERT_THROW(concat->inputs().at(Index{5}), std::out_of_range);
+}
index a24d1dc..fc39488 100644 (file)
@@ -17,6 +17,8 @@ public:
 public:
   virtual neurun::graph::operand::IndexSet inputs() const override { return {_input}; }
   virtual neurun::graph::operand::IndexSet outputs() const override { return {_output}; }
+  virtual void inputs(const neurun::graph::operand::IndexSet &indexes) override {}
+  virtual void outputs(const neurun::graph::operand::IndexSet &indexes) override {}
   virtual const ::internal::tflite::op::Node *op() const override { return nullptr; }
 
 private: