[loco] Simplify Identity Graph Testcase with Graph Builder (#6344)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 8 Aug 2019 05:42:38 +0000 (14:42 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 8 Aug 2019 05:42:38 +0000 (14:42 +0900)
* [loco] Simplify Identity Graph Testcase with Graph Builder

This commit simplifies Identity Graph Testcase with GraphBuilder.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Fix a typo

compiler/loco/src/Service/GraphBuilder.h
compiler/loco/src/Service/GraphTestcase.h

index 49cd58f..8c48755 100644 (file)
@@ -111,4 +111,102 @@ static inline std::unique_ptr<GraphBuilder> make_graph_builder(loco::Graph *g)
   return stdex::make_unique<GraphBuilder>(g);
 }
 
+// "InputLayer" creates both GraphInput and Pull node at once
+struct InputLayer final
+{
+  class Return
+  {
+  public:
+    Return(loco::GraphInput *input, loco::Pull *node) : _input{input}, _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::Pull *node(void) { return _node; }
+
+  public:
+    Return *name(const std::string &value)
+    {
+      _input->name(value);
+      return this;
+    }
+
+  public:
+    Return *shape(std::initializer_list<uint32_t> dims)
+    {
+      // TODO Uncomment this line when GraphInput is ready
+      // _graph_input->shape(dims)
+      _node->shape(dims);
+      return this;
+    }
+
+  private:
+    loco::GraphInput *_input = nullptr;
+    loco::Pull *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    auto input_index = ctx->graph()->inputs()->size();
+    auto graph_input = ctx->graph()->inputs()->create();
+
+    auto pull_node = ctx->graph()->nodes()->create<loco::Pull>();
+
+    pull_node->index(input_index);
+
+    loco::link(graph_input, pull_node);
+
+    ctx->stack()->push(pull_node);
+
+    return stdex::make_unique<Return>(graph_input, pull_node);
+  }
+};
+
+// "OutputLayer" creates both GraphOutput and Push node at once.
+struct OutputLayer final
+{
+  class Return
+  {
+  public:
+    Return(loco::GraphOutput *output, loco::Push *node) : _output{output}, _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::Push *node(void) { return _node; }
+
+  public:
+    Return *name(const std::string &value)
+    {
+      // TODO Uncomment this line when GraphOutput is ready
+      // _graph_output->shape(dims)
+      _output->name(value);
+      return this;
+    }
+
+  private:
+    loco::GraphOutput *_output = nullptr;
+    loco::Push *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    auto output_index = ctx->graph()->outputs()->size();
+    auto graph_output = ctx->graph()->outputs()->create();
+
+    auto push_node = ctx->graph()->nodes()->create<loco::Push>();
+
+    push_node->from(ctx->stack()->pop());
+    push_node->index(output_index);
+
+    loco::link(graph_output, push_node);
+
+    ctx->stack()->push(push_node);
+
+    return stdex::make_unique<Return>(graph_output, push_node);
+  }
+};
+
 #endif // __GRAPH_BUILDER_H__
index e3fce2b..37cac41 100644 (file)
@@ -4,6 +4,8 @@
 #include "loco/IR/Graph.h"
 #include "loco/IR/PermutingCodec.h"
 
+#include "GraphBuilder.h"
+
 #include <stdex/Memory.h>
 
 enum class GraphCode
@@ -27,6 +29,13 @@ private:
     // Create a sample network
     _graph = loco::make_graph();
 
+    auto graph_builder = make_graph_builder(_graph.get());
+
+    pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node();
+    push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+
+// TODO Remove deprecated code
+#if 0
     // Create Nodes
     pull_node = _graph->nodes()->create<loco::Pull>();
     pull_node->shape(dims);
@@ -56,6 +65,7 @@ private:
     // graph->output->shape(dims)
     loco::link(graph_output, push_node);
     push_node->index(0);
+#endif
   }
 
 public: