[loco] Implement Relu shape inference (#6161)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 5 Aug 2019 02:50:39 +0000 (11:50 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 5 Aug 2019 02:50:39 +0000 (11:50 +0900)
* [loco] Implement Relu shape inference

CanonicalShapeInferenceRule now properly infers the shape of relu node.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Add final

compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
compiler/loco/src/Service/GraphTestcase.h

index 77d887e..e18d4b2 100644 (file)
@@ -78,7 +78,9 @@ public:
     return loco::NodeShape{tensor_shape};
   }
 
-  // TODO Support ReLU
+  // CASE: ReLU
+  loco::NodeShape visit(const loco::ReLU *node) final { return loco::shape_get(node->input()); }
+
   // TODO Support ReLU6
   // TODO Support TensorBiasAdd
   // TODO SUpport TensorConcat
index c0e3b3f..537c26c 100644 (file)
@@ -44,3 +44,25 @@ TEST(CanonicalShapeInferenceRuleTest, minimal)
   ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
   ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
 }
+
+TEST(CanonicalShapeInferenceRuleTest, relu)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::Relu> testcase;
+
+  testcase.pull_node->shape({1, 2, 3, 4});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
+}
index 27cd90e..4ae4101 100644 (file)
@@ -6,6 +6,7 @@
 enum class GraphCode
 {
   Identity,
+  Relu,
 };
 
 template <GraphCode Code> class GraphTestcase;
@@ -49,4 +50,47 @@ private:
   std::unique_ptr<loco::Graph> _graph;
 };
 
+template <> class GraphTestcase<GraphCode::Relu> final
+{
+public:
+  GraphTestcase()
+  {
+    // Create a sample network
+    _graph = loco::make_graph();
+
+    // Create Nodes
+    pull_node = _graph->nodes()->create<loco::Pull>();
+
+    relu_node = _graph->nodes()->create<loco::ReLU>();
+    relu_node->input(pull_node);
+
+    push_node = _graph->nodes()->create<loco::Push>();
+    push_node->from(relu_node);
+
+    // Create Graph Input
+    auto graph_input = _graph->inputs()->create();
+
+    graph_input->name("input");
+    graph_input->node(pull_node);
+    pull_node->index(0);
+
+    // Create Graph Output
+    auto graph_output = _graph->outputs()->create();
+
+    graph_output->name("output");
+    graph_output->node(push_node);
+    push_node->index(0);
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::ReLU *relu_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
 #endif // __GRAPH_TESTCASE_H__