[tflite_loader] implement div and sqrt (#7114)
author이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Wed, 4 Sep 2019 01:42:04 +0000 (10:42 +0900)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Wed, 4 Sep 2019 01:42:04 +0000 (10:42 +0900)
It adds div and sqrt in tflite_loader.
`tf2tflite` converts `rsqrt` into `sqrt` and `div`.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
runtimes/neurun/frontend/tflite/loader.cc
runtimes/neurun/frontend/tflite/loader.h

index 9e0e8fb..09ed02c 100644 (file)
@@ -331,23 +331,23 @@ void Loader::loadAdd(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void Loader::loadMul(const tflite::Operator *op)
+void Loader::loadSub(const tflite::Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
 
   loadOperationIO(op, inputs, outputs);
 
-  model::operation::MulNode::Param param;
-  const auto *options = op->builtin_options_as_MulOptions();
+  model::operation::SubNode::Param param;
+  const auto *options = op->builtin_options_as_SubOptions();
 
-  param.activation = convertActivation(options->fused_activation_function());
+  param.activation = neurun::model::Activation(options->fused_activation_function());
 
-  std::unique_ptr<model::Operation> new_op(new model::operation::MulNode(inputs, outputs, param));
+  std::unique_ptr<model::Operation> new_op(new model::operation::SubNode(inputs, outputs, param));
   _graph.addOperation(std::move(new_op));
 }
 
-void Loader::loadSub(const tflite::Operator *op)
+void Loader::loadMul(const tflite::Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -363,6 +363,22 @@ void Loader::loadSub(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
+void Loader::loadDiv(const tflite::Operator *op)
+{
+  model::OperandIndexSequence inputs;
+  model::OperandIndexSequence outputs;
+
+  loadOperationIO(op, inputs, outputs);
+
+  model::operation::DivNode::Param param;
+  const auto *options = op->builtin_options_as_DivOptions();
+
+  param.activation = convertActivation(options->fused_activation_function());
+
+  std::unique_ptr<model::Operation> new_op(new model::operation::DivNode(inputs, outputs, param));
+  _graph.addOperation(std::move(new_op));
+}
+
 void Loader::loadRelu(const tflite::Operator *op)
 {
   model::OperandIndexSequence inputs;
@@ -385,6 +401,28 @@ void Loader::loadRelu6(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
+void Loader::loadRsqrt(const tflite::Operator *op)
+{
+  model::OperandIndexSequence inputs;
+  model::OperandIndexSequence outputs;
+
+  loadOperationIO(op, inputs, outputs);
+
+  std::unique_ptr<model::Operation> new_op(new model::operation::RSQRTNode(inputs, outputs));
+  _graph.addOperation(std::move(new_op));
+}
+
+void Loader::loadSqrt(const tflite::Operator *op)
+{
+  model::OperandIndexSequence inputs;
+  model::OperandIndexSequence outputs;
+
+  loadOperationIO(op, inputs, outputs);
+
+  std::unique_ptr<model::Operation> new_op(new model::operation::SQRTNode(inputs, outputs));
+  _graph.addOperation(std::move(new_op));
+}
+
 void Loader::loadSquaredDifference(const tflite::Operator *op)
 {
   model::OperandIndexSequence inputs;
@@ -430,17 +468,6 @@ void Loader::loadTranspose(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void Loader::loadRsqrt(const tflite::Operator *op)
-{
-  model::OperandIndexSequence inputs;
-  model::OperandIndexSequence outputs;
-
-  loadOperationIO(op, inputs, outputs);
-
-  std::unique_ptr<model::Operation> new_op(new model::operation::RSQRTNode(inputs, outputs));
-  _graph.addOperation(std::move(new_op));
-}
-
 void Loader::loadOperation(const tflite::Operator *op)
 {
   switch (_op_code_to_builtin_op[op->opcode_index()])
@@ -472,11 +499,14 @@ void Loader::loadOperation(const tflite::Operator *op)
     case BuiltinOperator_ADD:
       loadAdd(op);
       return;
+    case BuiltinOperator_SUB:
+      loadSub(op);
+      return;
     case BuiltinOperator_MUL:
       loadMul(op);
       return;
-    case BuiltinOperator_SUB:
-      loadSub(op);
+    case BuiltinOperator_DIV:
+      loadDiv(op);
       return;
     case BuiltinOperator_RELU:
       loadRelu(op);
@@ -487,6 +517,9 @@ void Loader::loadOperation(const tflite::Operator *op)
     case BuiltinOperator_RSQRT:
       loadRsqrt(op);
       return;
+    case BuiltinOperator_SQRT:
+      loadSqrt(op);
+      return;
     case BuiltinOperator_SQUARED_DIFFERENCE:
       loadSquaredDifference(op);
       return;
index 01f6614..a13f8ba 100644 (file)
@@ -83,11 +83,13 @@ private:
   void loadConcatenation(const tflite::Operator *op);
   void loadFC(const tflite::Operator *op);
   void loadAdd(const tflite::Operator *op);
-  void loadMul(const tflite::Operator *op);
   void loadSub(const tflite::Operator *op);
+  void loadMul(const tflite::Operator *op);
+  void loadDiv(const tflite::Operator *op);
   void loadRelu(const tflite::Operator *op);
   void loadRelu6(const tflite::Operator *op);
   void loadRsqrt(const tflite::Operator *op);
+  void loadSqrt(const tflite::Operator *op);
   void loadSquaredDifference(const tflite::Operator *op);
   void loadTanh(const tflite::Operator *op);
   void loadTranspose(const tflite::Operator *op);