[loco] FixedReshape graph testcase (#6545)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Wed, 14 Aug 2019 01:56:31 +0000 (10:56 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 14 Aug 2019 01:56:31 +0000 (10:56 +0900)
This commit adds graph builder layer and testcase for FixedReshape

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/loco/src/Service/GraphBuilder.h
compiler/loco/src/Service/GraphTestcase.h

index e47782f..20b06de 100644 (file)
@@ -321,4 +321,40 @@ struct FeatureDecodeLayer final
   }
 };
 
+struct FixedReshapeLayer final
+{
+  class Return
+  {
+  public:
+    Return(loco::FixedReshape *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    Return *shape(std::initializer_list<uint32_t> dims)
+    {
+      _node->shape(dims);
+      return this;
+    }
+
+  public:
+    loco::FixedReshape *node(void) { return _node; }
+
+  private:
+    loco::FixedReshape *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    auto reshape_node = ctx->graph()->nodes()->create<loco::FixedReshape>();
+
+    reshape_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(reshape_node);
+
+    return stdex::make_unique<Return>(reshape_node);
+  }
+};
+
 #endif // __GRAPH_BUILDER_H__
index c848ba0..9f1004f 100644 (file)
@@ -17,6 +17,7 @@ enum class GraphCode
   AvgPool2D,
   MaxPool2D,
   TensorConcat,
+  FixedReshape,
 };
 
 template <GraphCode Code> class GraphTestcase;
@@ -399,6 +400,31 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::FixedReshape> final
+{
+public:
+  GraphTestcase()
+  {
+    _graph = loco::make_graph();
+
+    auto graph_builder = make_graph_builder(_graph.get());
+
+    pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+    reshape_node = graph_builder->push<FixedReshapeLayer>()->node();
+    push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::FixedReshape *reshape_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
 namespace
 {