From 90a2cd07ead4724fd75773a1524af1c3e1e77af9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=EC=83=81=EA=B7=9C/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 3 Sep 2019 17:23:17 +0900 Subject: [PATCH] [tflite_loader] implement squared_difference (#7127) It adds squared_difference operator in tflite_loader. Signed-off-by: Sanggyu Lee --- runtimes/neurun/frontend/tflite/loader.cc | 21 ++++++++++++++++++--- runtimes/neurun/frontend/tflite/loader.h | 3 ++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/runtimes/neurun/frontend/tflite/loader.cc b/runtimes/neurun/frontend/tflite/loader.cc index ba95cc8..9e0e8fb 100644 --- a/runtimes/neurun/frontend/tflite/loader.cc +++ b/runtimes/neurun/frontend/tflite/loader.cc @@ -385,6 +385,18 @@ void Loader::loadRelu6(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } +void Loader::loadSquaredDifference(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + std::unique_ptr new_op( + new model::operation::SquaredDifferenceNode(inputs, outputs)); + _graph.addOperation(std::move(new_op)); +} + void Loader::loadTanh(const tflite::Operator *op) { model::OperandIndexSequence inputs; @@ -472,15 +484,18 @@ void Loader::loadOperation(const tflite::Operator *op) case BuiltinOperator_RELU6: loadRelu6(op); return; + case BuiltinOperator_RSQRT: + loadRsqrt(op); + return; + case BuiltinOperator_SQUARED_DIFFERENCE: + loadSquaredDifference(op); + return; case BuiltinOperator_TANH: loadTanh(op); return; case BuiltinOperator_TRANSPOSE: loadTranspose(op); return; - case BuiltinOperator_RSQRT: - loadRsqrt(op); - return; default: auto *names = EnumNamesBuiltinOperator(); int enum_value = static_cast(_op_code_to_builtin_op[op->opcode_index()]); diff --git a/runtimes/neurun/frontend/tflite/loader.h b/runtimes/neurun/frontend/tflite/loader.h index 2d7d44d..01f6614 100644 --- a/runtimes/neurun/frontend/tflite/loader.h +++ b/runtimes/neurun/frontend/tflite/loader.h @@ -88,8 +88,9 @@ private: void loadRelu(const tflite::Operator *op); void loadRelu6(const tflite::Operator *op); void loadRsqrt(const tflite::Operator *op); - void loadTranspose(const tflite::Operator *op); + void loadSquaredDifference(const tflite::Operator *op); void loadTanh(const tflite::Operator *op); + void loadTranspose(const tflite::Operator *op); private: // Buffer for loading (if needed) -- 2.7.4