From 15da670e646d2fa4802e683e883d48d5b77a4477 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 5 Aug 2019 11:50:39 +0900 Subject: [PATCH] [loco] Implement Relu shape inference (#6161) * [loco] Implement Relu shape inference CanonicalShapeInferenceRule now properly infers the shape of relu node. Signed-off-by: Jonghyun Park * Add final --- .../src/Service/CanonicalShapeInferenceRule.cpp | 4 +- .../Service/CanonicalShapeInferenceRule.test.cpp | 22 +++++++++++ compiler/loco/src/Service/GraphTestcase.h | 44 ++++++++++++++++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 77d887e..e18d4b2 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -78,7 +78,9 @@ public: return loco::NodeShape{tensor_shape}; } - // TODO Support ReLU + // CASE: ReLU + loco::NodeShape visit(const loco::ReLU *node) final { return loco::shape_get(node->input()); } + // TODO Support ReLU6 // TODO Support TensorBiasAdd // TODO SUpport TensorConcat diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index c0e3b3f..537c26c 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -44,3 +44,25 @@ TEST(CanonicalShapeInferenceRuleTest, minimal) ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(2), 3); ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(3), 4); } + +TEST(CanonicalShapeInferenceRuleTest, relu) +{ + // Create a sample network + GraphTestcase testcase; + + testcase.pull_node->shape({1, 2, 3, 4}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(3), 4); +} diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 27cd90e..4ae4101 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -6,6 +6,7 @@ enum class GraphCode { Identity, + Relu, }; template class GraphTestcase; @@ -49,4 +50,47 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + // Create a sample network + _graph = loco::make_graph(); + + // Create Nodes + pull_node = _graph->nodes()->create(); + + relu_node = _graph->nodes()->create(); + relu_node->input(pull_node); + + push_node = _graph->nodes()->create(); + push_node->from(relu_node); + + // Create Graph Input + auto graph_input = _graph->inputs()->create(); + + graph_input->name("input"); + graph_input->node(pull_node); + pull_node->index(0); + + // Create Graph Output + auto graph_output = _graph->outputs()->create(); + + graph_output->name("output"); + graph_output->node(push_node); + push_node->index(0); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::ReLU *relu_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + #endif // __GRAPH_TESTCASE_H__ -- 2.7.4