From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Mon, 5 Aug 2019 02:50:39 +0000 (+0900) Subject: [loco] Implement Relu shape inference (#6161) X-Git-Tag: submit/tizen/20190809.050447~197 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=15da670e646d2fa4802e683e883d48d5b77a4477;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Implement Relu shape inference (#6161) * [loco] Implement Relu shape inference CanonicalShapeInferenceRule now properly infers the shape of relu node. Signed-off-by: Jonghyun Park * Add final --- diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 77d887e..e18d4b2 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -78,7 +78,9 @@ public: return loco::NodeShape{tensor_shape}; } - // TODO Support ReLU + // CASE: ReLU + loco::NodeShape visit(const loco::ReLU *node) final { return loco::shape_get(node->input()); } + // TODO Support ReLU6 // TODO Support TensorBiasAdd // TODO SUpport TensorConcat diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp index c0e3b3f..537c26c 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -44,3 +44,25 @@ TEST(CanonicalShapeInferenceRuleTest, minimal) ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(2), 3); ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(3), 4); } + +TEST(CanonicalShapeInferenceRuleTest, relu) +{ + // Create a sample network + GraphTestcase testcase; + + testcase.pull_node->shape({1, 2, 3, 4}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.push_node).as().dim(3), 4); +} diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h index 27cd90e..4ae4101 100644 --- a/compiler/loco/src/Service/GraphTestcase.h +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -6,6 +6,7 @@ enum class GraphCode { Identity, + Relu, }; template class GraphTestcase; @@ -49,4 +50,47 @@ private: std::unique_ptr _graph; }; +template <> class GraphTestcase final +{ +public: + GraphTestcase() + { + // Create a sample network + _graph = loco::make_graph(); + + // Create Nodes + pull_node = _graph->nodes()->create(); + + relu_node = _graph->nodes()->create(); + relu_node->input(pull_node); + + push_node = _graph->nodes()->create(); + push_node->from(relu_node); + + // Create Graph Input + auto graph_input = _graph->inputs()->create(); + + graph_input->name("input"); + graph_input->node(pull_node); + pull_node->index(0); + + // Create Graph Output + auto graph_output = _graph->outputs()->create(); + + graph_output->name("output"); + graph_output->node(push_node); + push_node->index(0); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::ReLU *relu_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr _graph; +}; + #endif // __GRAPH_TESTCASE_H__