[tflite_loader] implement squared_difference (#7127)
author이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 <sg5.lee@samsung.com>
Tue, 3 Sep 2019 08:23:17 +0000 (17:23 +0900)
committer이춘석/On-Device Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Tue, 3 Sep 2019 08:23:17 +0000 (17:23 +0900)
It adds squared_difference operator in tflite_loader.

Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
runtimes/neurun/frontend/tflite/loader.cc
runtimes/neurun/frontend/tflite/loader.h

index ba95cc8..9e0e8fb 100644 (file)
@@ -385,6 +385,18 @@ void Loader::loadRelu6(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
+void Loader::loadSquaredDifference(const tflite::Operator *op)
+{
+  model::OperandIndexSequence inputs;
+  model::OperandIndexSequence outputs;
+
+  loadOperationIO(op, inputs, outputs);
+
+  std::unique_ptr<model::Operation> new_op(
+      new model::operation::SquaredDifferenceNode(inputs, outputs));
+  _graph.addOperation(std::move(new_op));
+}
+
 void Loader::loadTanh(const tflite::Operator *op)
 {
   model::OperandIndexSequence inputs;
@@ -472,15 +484,18 @@ void Loader::loadOperation(const tflite::Operator *op)
     case BuiltinOperator_RELU6:
       loadRelu6(op);
       return;
+    case BuiltinOperator_RSQRT:
+      loadRsqrt(op);
+      return;
+    case BuiltinOperator_SQUARED_DIFFERENCE:
+      loadSquaredDifference(op);
+      return;
     case BuiltinOperator_TANH:
       loadTanh(op);
       return;
     case BuiltinOperator_TRANSPOSE:
       loadTranspose(op);
       return;
-    case BuiltinOperator_RSQRT:
-      loadRsqrt(op);
-      return;
     default:
       auto *names = EnumNamesBuiltinOperator();
       int enum_value = static_cast<int>(_op_code_to_builtin_op[op->opcode_index()]);
index 2d7d44d..01f6614 100644 (file)
@@ -88,8 +88,9 @@ private:
   void loadRelu(const tflite::Operator *op);
   void loadRelu6(const tflite::Operator *op);
   void loadRsqrt(const tflite::Operator *op);
-  void loadTranspose(const tflite::Operator *op);
+  void loadSquaredDifference(const tflite::Operator *op);
   void loadTanh(const tflite::Operator *op);
+  void loadTranspose(const tflite::Operator *op);
 
 private:
   // Buffer for loading (if needed)