[loco] Introduce constgen layer (#6645)
author채성우/On-Device Lab(SR)/Engineer/삼성전자 <sw4670.chae@samsung.com>
Sun, 18 Aug 2019 23:29:05 +0000 (08:29 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Sun, 18 Aug 2019 23:29:05 +0000 (08:29 +0900)
This commit introduce constgen layer to GraphBuilder.

Signed-off-by: seongwoo <sw4670.chae@samsung.com>
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphBuilder.h
compiler/loco/src/Service/GraphTestcase.h

index 97114d3..ac07b3f 100644 (file)
@@ -48,6 +48,9 @@ TEST(CanonicalShapeInferenceRuleTest, const_gen)
   // Create a sample network
   GraphTestcase<GraphCode::ConstGen> testcase;
 
+  testcase.const_node->dtype(loco::DataType::FLOAT32);
+  testcase.const_node->shape({1, 2});
+
   // Run Inference
   loco::CanonicalShapeInferenceRule rule;
 
index 20b06de..128763d 100644 (file)
@@ -240,6 +240,33 @@ struct ReLULayer final
   }
 };
 
+struct ConstGenLayer final
+{
+  class Return
+  {
+  public:
+    Return(loco::ConstGen *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::ConstGen *node(void) { return _node; }
+
+  private:
+    loco::ConstGen *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    auto const_node = ctx->graph()->nodes()->create<loco::ConstGen>();
+
+    ctx->stack()->push(const_node);
+
+    return stdex::make_unique<Return>(const_node);
+  }
+};
+
 #include "loco/IR/PermutingCodec.h"
 
 struct FeatureEncodeLayer final
index 40a4435..e515746 100644 (file)
@@ -56,26 +56,13 @@ template <> class GraphTestcase<GraphCode::ConstGen> final
 public:
   GraphTestcase()
   {
-    // Create a sample network
     _graph = loco::make_graph();
 
-    // Create Nodes
-    const_node = _graph->nodes()->create<loco::ConstGen>();
-    const_node->dtype(loco::DataType::FLOAT32);
-    const_node->shape({1, 2});
-
-    push_node = _graph->nodes()->create<loco::Push>();
-    push_node->from(const_node);
-
-    // Create Graph Output
-    auto graph_output = _graph->outputs()->create();
+    auto graph_builder = make_graph_builder(_graph.get());
 
-    graph_output->name("output");
-    graph_output->dtype(loco::DataType::FLOAT32);
-    graph_output->shape({1, 2});
+    const_node = graph_builder->push<ConstGenLayer>()->node();
 
-    loco::link(graph_output, push_node);
-    push_node->index(0);
+    push_node = graph_builder->push<OutputLayer>()->name("output")->node();
   }
 
 public: