From: Ivan Vagin/AI Tools Lab /SRR/Engineer/삼성전자 Date: Mon, 9 Sep 2019 03:04:41 +0000 (+0300) Subject: [tflite_loader] Implemented TransposeConv node conversion (#7268) X-Git-Tag: accepted/tizen/unified/20190911.111615~45 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=991de073767ff71dcb203dd76af4843c51fc22d9;p=platform%2Fcore%2Fml%2Fnnfw.git [tflite_loader] Implemented TransposeConv node conversion (#7268) * [tflite_loader] Implemented TransposeConv node conversion Implemented transpose conv node conversion Signed-off-by: Ivan Vagin * Fixed formatting Signed-off-by: Ivan Vagin --- diff --git a/runtimes/neurun/frontend/tflite/loader.cc b/runtimes/neurun/frontend/tflite/loader.cc index cc1d522..844365b 100644 --- a/runtimes/neurun/frontend/tflite/loader.cc +++ b/runtimes/neurun/frontend/tflite/loader.cc @@ -212,6 +212,21 @@ void Loader::loadDepthwiseConv2D(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } +void Loader::loadTransposeConv(const tflite::Operator *op) +{ + model::OperandIndexSequence inputs; + model::OperandIndexSequence outputs; + + loadOperationIO(op, inputs, outputs); + + model::operation::TransposeConvNode::Param param; + const auto *options = op->builtin_options_as_TransposeConvOptions(); + loadStridesAndPaddings(param, options); + std::unique_ptr new_op( + new model::operation::TransposeConvNode(inputs, outputs, param)); + _graph.addOperation(std::move(new_op)); +} + void Loader::loadAvgPool2D(const tflite::Operator *op) { model::OperandIndexSequence inputs; @@ -535,6 +550,9 @@ void Loader::loadOperation(const tflite::Operator *op) case BuiltinOperator_DEPTHWISE_CONV_2D: loadDepthwiseConv2D(op); return; + case BuiltinOperator_TRANSPOSE_CONV: + loadTransposeConv(op); + return; case BuiltinOperator_RESHAPE: loadReshape(op); return; diff --git a/runtimes/neurun/frontend/tflite/loader.h b/runtimes/neurun/frontend/tflite/loader.h index 0ccd948..c398cbc 100644 --- a/runtimes/neurun/frontend/tflite/loader.h +++ b/runtimes/neurun/frontend/tflite/loader.h @@ -76,6 +76,7 @@ private: // Operations void loadConv2D(const tflite::Operator *op); void loadDepthwiseConv2D(const tflite::Operator *op); + void loadTransposeConv(const tflite::Operator *op); void loadAvgPool2D(const tflite::Operator *op); void loadReshape(const tflite::Operator *op); void loadSoftmax(const tflite::Operator *op);