From 9ed115dc4710ae4417ae4b870023ec139f57695a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=B2=9C=EA=B5=90/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 14 Aug 2019 10:56:31 +0900 Subject: [PATCH] [loco] FixedReshape graph testcase (#6545) This commit adds graph builder layer and testcase for FixedReshape Signed-off-by: Cheongyo Bahk --- compiler/loco/src/Service/GraphBuilder.h | 36 +++++++++++++++++++++++++++++++ compiler/loco/src/Service/GraphTestcase.h | 26 ++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/compiler/loco/src/Service/GraphBuilder.h b/compiler/loco/src/Service/GraphBuilder.h index e47782f..20b06de 100644 --- a/compiler/loco/src/Service/GraphBuilder.h +++ b/compiler/loco/src/Service/GraphBuilder.h @@ -321,4 +321,40 @@ struct FeatureDecodeLayer final } }; +struct FixedReshapeLayer final +{ + class Return + { + public: + Return(loco::FixedReshape *node) : _node{node} + { + // DO NOTHING + } + + public: + Return *shape(std::initializer_list dims) + { + _node->shape(dims); + return this; + } + + public: + loco::FixedReshape *node(void) { return _node; } + + private: + loco::FixedReshape *_node = nullptr; + }; + + std::unique_ptr operator()(GraphBuilder::Context *ctx) + { + auto reshape_node = ctx->graph()->nodes()->create(); + + reshape_node->input(ctx->stack()->pop()); + + ctx->stack()->push(reshape_node); + + return stdex::make_unique(reshape_node); + } +}; + #endif // __GRAPH_BUILDER_H__ diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index c848ba0..9f1004f 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -17,6 +17,7 @@ enum class GraphCode AvgPool2D, MaxPool2D, TensorConcat, + FixedReshape, }; template class GraphTestcase; @@ -399,6 +400,31 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + _graph = loco::make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push()->name("input")->node(); + reshape_node = graph_builder->push()->node(); + push_node = graph_builder->push()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FixedReshape *reshape_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + namespace { -- 2.7.4