From fee76ce3e289ab74fc552f432da06b4043316c5c Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 8 Aug 2019 16:39:37 +0900 Subject: [PATCH] [loco] Simplify ReLU Graph Testcase (#6392) Let's simplify ReLU Graph Testcase using GraphBuilder. Signed-off-by: Jonghyun Park --- compiler/loco/src/Service/GraphBuilder.h | 31 +++++++++++++++++++++++++++++++ compiler/loco/src/Service/GraphTestcase.h | 12 ++++++++++++ 2 files changed, 43 insertions(+) diff --git a/compiler/loco/src/Service/GraphBuilder.h b/compiler/loco/src/Service/GraphBuilder.h index 8c48755..5c517a3 100644 --- a/compiler/loco/src/Service/GraphBuilder.h +++ b/compiler/loco/src/Service/GraphBuilder.h @@ -209,4 +209,35 @@ struct OutputLayer final } }; +struct ReLULayer final +{ + // This "Return" is unnecessary for ReLU as ReLU has no attributes), but + // introduced for consistency. + class Return + { + public: + Return(loco::ReLU *node) : _node{node} + { + // DO NOTHING + } + + public: + loco::ReLU *node(void) { return _node; } + + private: + loco::ReLU *_node = nullptr; + }; + + std::unique_ptr operator()(GraphBuilder::Context *ctx) + { + auto relu_node = ctx->graph()->nodes()->create(); + + relu_node->input(ctx->stack()->pop()); + + ctx->stack()->push(relu_node); + + return stdex::make_unique(relu_node); + } +}; + #endif // __GRAPH_BUILDER_H__ diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 37cac41..ac23c59 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -128,6 +128,17 @@ public: // Create a sample network _graph = loco::make_graph(); + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push()->name("input")->node(); + relu_node = graph_builder->push()->node(); + push_node = graph_builder->push()->name("output")->node(); + +// TODO Remove deprecated code +#if 0 + // Create a sample network + _graph = loco::make_graph(); + // Create Nodes pull_node = _graph->nodes()->create(); @@ -150,6 +161,7 @@ public: graph_output->name("output"); loco::link(graph_output, push_node); push_node->index(0); +#endif } public: -- 2.7.4