[exo-tflite] Support ReLU6 (#5884)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 25 Jul 2019 07:23:11 +0000 (16:23 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Thu, 25 Jul 2019 07:23:11 +0000 (16:23 +0900)
TFLExporter is now able to export ReLU6 node as T/F Lite RELU6
operation.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/exo-tflite/src/OperationExporter.cpp
compiler/exo-tflite/src/ShapeInference.cpp
compiler/exo-tflite/src/TFLExporterImpl.test.cpp

index d88f1f5..22f00f9 100644 (file)
@@ -38,6 +38,7 @@ public:
 
 public:
   void visit(loco::ReLU *) final;
+  void visit(loco::ReLU6 *) final;
   void visit(loco::Push *) final { /* DO NOTHING */}
   void visit(loco::Pull *) final { /* DO NOTHING */}
   void visit(loco::FeatureEncode *) final;
@@ -68,6 +69,17 @@ void OperationExporter::visit(loco::ReLU *node)
   gd._operators.push_back(op_offset);
 }
 
+void OperationExporter::visit(loco::ReLU6 *node)
+{
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU6);
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->input())};
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs);
+  gd._operators.push_back(op_offset);
+}
+
 void OperationExporter::visit(loco::MaxPool2D *node)
 {
   uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_MAX_POOL_2D);
index 968aa13..3f91531 100644 (file)
@@ -81,6 +81,8 @@ public:
   NODE(FeatureBiasAdd)
 #undef NODE
   // TODO Put all the visit method implementations inside this class declaration
+  ShapeDescription visit(loco::ReLU6 *node) { return gd._node_to_shape[node->input()]; }
+
 private:
   ShapeContext &gd;
 };
index 08fd90d..3d9b921 100644 (file)
@@ -78,6 +78,39 @@ template <> loco::FeatureDecode *TFLExporterImplTests::make_node(void)
 
 } // namespace
 
+TEST_F(TFLExporterImplTests, Relu6)
+{
+  auto pull = make_node<loco::Pull>();
+  {
+    pull->dtype(loco::DataType::FLOAT32);
+    pull->shape({1, 8, 8, 3});
+  }
+  auto relu6 = make_node<loco::ReLU6>();
+  {
+    relu6->input(pull);
+  }
+  auto push = make_node<loco::Push>();
+  {
+    push->from(relu6);
+  }
+
+  auto input = graph()->inputs()->create();
+  {
+    input->name("input");
+    input->node(pull);
+  }
+  auto output = graph()->outputs()->create();
+  {
+    output->name("output");
+    output->node(push);
+  }
+
+  exo::TFLExporter::Impl exporter{graph()};
+
+  // TODO Add more checks
+  SUCCEED();
+}
+
 /**
  * What happens when there is a mismatch between generation and execution order!?
  */