From: 이상규/On-Device Lab(SR)/Principal Engineer/삼성전자 Date: Wed, 4 Sep 2019 01:42:04 +0000 (+0900) Subject: [tflite_loader] implement div and sqrt (#7114) X-Git-Tag: accepted/tizen/unified/20190904.110638~11 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=30ce785f3d36131fd94155d5c4a0ed1a1b9b5735;p=platform%2Fcore%2Fml%2Fnnfw.git [tflite_loader] implement div and sqrt (#7114) It adds div and sqrt in tflite_loader. `tf2tflite` converts `rsqrt` into `sqrt` and `div`. Signed-off-by: Sanggyu Lee --- diff --git a/runtimes/neurun/frontend/tflite/loader.cc b/runtimes/neurun/frontend/tflite/loader.cc index 9e0e8fb..09ed02c 100644 --- a/runtimes/neurun/frontend/tflite/loader.cc +++ b/runtimes/neurun/frontend/tflite/loader.cc @@ -331,23 +331,23 @@ void Loader::loadAdd(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void Loader::loadMul(const tflite::Operator *op) +void Loader::loadSub(const tflite::Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; loadOperationIO(op, inputs, outputs); - model::operation::MulNode::Param param; - const auto *options = op->builtin_options_as_MulOptions(); + model::operation::SubNode::Param param; + const auto *options = op->builtin_options_as_SubOptions(); - param.activation = convertActivation(options->fused_activation_function()); + param.activation = neurun::model::Activation(options->fused_activation_function()); - std::unique_ptr new_op(new model::operation::MulNode(inputs, outputs, param)); + std::unique_ptr new_op(new model::operation::SubNode(inputs, outputs, param)); _graph.addOperation(std::move(new_op)); } -void Loader::loadSub(const tflite::Operator *op) +void Loader::loadMul(const tflite::Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -363,6 +363,22 @@ void Loader::loadSub(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } +void Loader::loadDiv(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + model::operation::DivNode::Param param; + const auto *options = op->builtin_options_as_DivOptions(); + + param.activation = convertActivation(options->fused_activation_function()); + + std::unique_ptr new_op(new model::operation::DivNode(inputs, outputs, param)); + _graph.addOperation(std::move(new_op)); +} + void Loader::loadRelu(const tflite::Operator *op) { model::OperandIndexSequence inputs; @@ -385,6 +401,28 @@ void Loader::loadRelu6(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } +void Loader::loadRsqrt(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + std::unique_ptr new_op(new model::operation::RSQRTNode(inputs, outputs)); + _graph.addOperation(std::move(new_op)); +} + +void Loader::loadSqrt(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + std::unique_ptr new_op(new model::operation::SQRTNode(inputs, outputs)); + _graph.addOperation(std::move(new_op)); +} + void Loader::loadSquaredDifference(const tflite::Operator *op) { model::OperandIndexSequence inputs; @@ -430,17 +468,6 @@ void Loader::loadTranspose(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void Loader::loadRsqrt(const tflite::Operator *op) -{ - model::OperandIndexSequence inputs; - model::OperandIndexSequence outputs; - - loadOperationIO(op, inputs, outputs); - - std::unique_ptr new_op(new model::operation::RSQRTNode(inputs, outputs)); - _graph.addOperation(std::move(new_op)); -} - void Loader::loadOperation(const tflite::Operator *op) { switch (_op_code_to_builtin_op[op->opcode_index()]) @@ -472,11 +499,14 @@ void Loader::loadOperation(const tflite::Operator *op) case BuiltinOperator_ADD: loadAdd(op); return; + case BuiltinOperator_SUB: + loadSub(op); + return; case BuiltinOperator_MUL: loadMul(op); return; - case BuiltinOperator_SUB: - loadSub(op); + case BuiltinOperator_DIV: + loadDiv(op); return; case BuiltinOperator_RELU: loadRelu(op); @@ -487,6 +517,9 @@ void Loader::loadOperation(const tflite::Operator *op) case BuiltinOperator_RSQRT: loadRsqrt(op); return; + case BuiltinOperator_SQRT: + loadSqrt(op); + return; case BuiltinOperator_SQUARED_DIFFERENCE: loadSquaredDifference(op); return; diff --git a/runtimes/neurun/frontend/tflite/loader.h b/runtimes/neurun/frontend/tflite/loader.h index 01f6614..a13f8ba 100644 --- a/runtimes/neurun/frontend/tflite/loader.h +++ b/runtimes/neurun/frontend/tflite/loader.h @@ -83,11 +83,13 @@ private: void loadConcatenation(const tflite::Operator *op); void loadFC(const tflite::Operator *op); void loadAdd(const tflite::Operator *op); - void loadMul(const tflite::Operator *op); void loadSub(const tflite::Operator *op); + void loadMul(const tflite::Operator *op); + void loadDiv(const tflite::Operator *op); void loadRelu(const tflite::Operator *op); void loadRelu6(const tflite::Operator *op); void loadRsqrt(const tflite::Operator *op); + void loadSqrt(const tflite::Operator *op); void loadSquaredDifference(const tflite::Operator *op); void loadTanh(const tflite::Operator *op); void loadTranspose(const tflite::Operator *op);