From 30ce785f3d36131fd94155d5c4a0ed1a1b9b5735 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=EC=83=81=EA=B7=9C/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 4 Sep 2019 10:42:04 +0900 Subject: [PATCH] [tflite_loader] implement div and sqrt (#7114) It adds div and sqrt in tflite_loader. `tf2tflite` converts `rsqrt` into `sqrt` and `div`. Signed-off-by: Sanggyu Lee --- runtimes/neurun/frontend/tflite/loader.cc | 71 ++++++++++++++++++++++--------- runtimes/neurun/frontend/tflite/loader.h | 4 +- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/runtimes/neurun/frontend/tflite/loader.cc b/runtimes/neurun/frontend/tflite/loader.cc index 9e0e8fb..09ed02c 100644 --- a/runtimes/neurun/frontend/tflite/loader.cc +++ b/runtimes/neurun/frontend/tflite/loader.cc @@ -331,23 +331,23 @@ void Loader::loadAdd(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void Loader::loadMul(const tflite::Operator *op) +void Loader::loadSub(const tflite::Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; loadOperationIO(op, inputs, outputs); - model::operation::MulNode::Param param; - const auto *options = op->builtin_options_as_MulOptions(); + model::operation::SubNode::Param param; + const auto *options = op->builtin_options_as_SubOptions(); - param.activation = convertActivation(options->fused_activation_function()); + param.activation = neurun::model::Activation(options->fused_activation_function()); - std::unique_ptr new_op(new model::operation::MulNode(inputs, outputs, param)); + std::unique_ptr new_op(new model::operation::SubNode(inputs, outputs, param)); _graph.addOperation(std::move(new_op)); } -void Loader::loadSub(const tflite::Operator *op) +void Loader::loadMul(const tflite::Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -363,6 +363,22 @@ void Loader::loadSub(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } +void Loader::loadDiv(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + model::operation::DivNode::Param param; + const auto *options = op->builtin_options_as_DivOptions(); + + param.activation = convertActivation(options->fused_activation_function()); + + std::unique_ptr new_op(new model::operation::DivNode(inputs, outputs, param)); + _graph.addOperation(std::move(new_op)); +} + void Loader::loadRelu(const tflite::Operator *op) { model::OperandIndexSequence inputs; @@ -385,6 +401,28 @@ void Loader::loadRelu6(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } +void Loader::loadRsqrt(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + std::unique_ptr new_op(new model::operation::RSQRTNode(inputs, outputs)); + _graph.addOperation(std::move(new_op)); +} + +void Loader::loadSqrt(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + std::unique_ptr new_op(new model::operation::SQRTNode(inputs, outputs)); + _graph.addOperation(std::move(new_op)); +} + void Loader::loadSquaredDifference(const tflite::Operator *op) { model::OperandIndexSequence inputs; @@ -430,17 +468,6 @@ void Loader::loadTranspose(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void Loader::loadRsqrt(const tflite::Operator *op) -{ - model::OperandIndexSequence inputs; - model::OperandIndexSequence outputs; - - loadOperationIO(op, inputs, outputs); - - std::unique_ptr new_op(new model::operation::RSQRTNode(inputs, outputs)); - _graph.addOperation(std::move(new_op)); -} - void Loader::loadOperation(const tflite::Operator *op) { switch (_op_code_to_builtin_op[op->opcode_index()]) @@ -472,11 +499,14 @@ void Loader::loadOperation(const tflite::Operator *op) case BuiltinOperator_ADD: loadAdd(op); return; + case BuiltinOperator_SUB: + loadSub(op); + return; case BuiltinOperator_MUL: loadMul(op); return; - case BuiltinOperator_SUB: - loadSub(op); + case BuiltinOperator_DIV: + loadDiv(op); return; case BuiltinOperator_RELU: loadRelu(op); @@ -487,6 +517,9 @@ void Loader::loadOperation(const tflite::Operator *op) case BuiltinOperator_RSQRT: loadRsqrt(op); return; + case BuiltinOperator_SQRT: + loadSqrt(op); + return; case BuiltinOperator_SQUARED_DIFFERENCE: loadSquaredDifference(op); return; diff --git a/runtimes/neurun/frontend/tflite/loader.h b/runtimes/neurun/frontend/tflite/loader.h index 01f6614..a13f8ba 100644 --- a/runtimes/neurun/frontend/tflite/loader.h +++ b/runtimes/neurun/frontend/tflite/loader.h @@ -83,11 +83,13 @@ private: void loadConcatenation(const tflite::Operator *op); void loadFC(const tflite::Operator *op); void loadAdd(const tflite::Operator *op); - void loadMul(const tflite::Operator *op); void loadSub(const tflite::Operator *op); + void loadMul(const tflite::Operator *op); + void loadDiv(const tflite::Operator *op); void loadRelu(const tflite::Operator *op); void loadRelu6(const tflite::Operator *op); void loadRsqrt(const tflite::Operator *op); + void loadSqrt(const tflite::Operator *op); void loadSquaredDifference(const tflite::Operator *op); void loadTanh(const tflite::Operator *op); void loadTranspose(const tflite::Operator *op); -- 2.7.4