[loco] Support ReLU6 type inference (#5877)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 25 Jul 2019 06:21:25 +0000 (15:21 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 25 Jul 2019 06:21:25 +0000 (15:21 +0900)
Now, type inference framework supports ReLU6 node.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/loco/src/Service/TypeInference.cpp
compiler/loco/src/Service/TypeInference.test.cpp

index 93e6031..a993938 100644 (file)
@@ -110,7 +110,7 @@ struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<l
   loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); }
   loco::DataType visit(const loco::Pull *node) { return node->dtype(); }
   loco::DataType visit(const loco::ReLU *node) { return loco::dtype_get(node->input()); }
-  // TODO Support ReLU6
+  loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); }
   loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
   loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
 };
index 84c4cbf..238dde3 100644 (file)
@@ -126,3 +126,40 @@ TEST(CanonicalTypeInferenceRuleTest, minimal)
   ASSERT_TRUE(loco::dtype_known(push_node));
   ASSERT_EQ(loco::dtype_get(push_node), loco::DataType::U8);
 }
+
+TEST(CanonicalTypeInferenceRuleTest, relu6)
+{
+  // Create a simple Relu6 network
+  auto g = loco::make_graph();
+
+  auto pull_node = g->nodes()->create<loco::Pull>();
+
+  pull_node->dtype(loco::DataType::FLOAT32);
+
+  auto relu6_node = g->nodes()->create<loco::ReLU6>();
+
+  relu6_node->input(pull_node);
+
+  auto push_node = g->nodes()->create<loco::Push>();
+
+  push_node->from(relu6_node);
+
+  auto graph_input = g->inputs()->create();
+
+  graph_input->name("input");
+  graph_input->node(pull_node);
+
+  auto graph_output = g->outputs()->create();
+
+  graph_output->name("output");
+  graph_output->node(push_node);
+
+  // Run Type Inference
+  loco::CanonicalTypeInferenceRule rule;
+
+  loco::apply(&rule).to(g.get());
+
+  // Verify!
+  ASSERT_TRUE(loco::dtype_known(relu6_node));
+  ASSERT_EQ(loco::dtype_get(relu6_node), loco::DataType::FLOAT32);
+}