From bc0e75f5d7694b31f4ed8e3991cd434ed9eded91 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dmitry=20Mozolev/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 27 Jun 2018 10:19:44 +0300 Subject: [PATCH] Add Model IR operator creator skeleton for TFLite (#354) Add Model IR operator creator skeleton for TFLite This is a Model IR operator creator class skeleton. Contains public interface for creating supported operations. Signed-off-by: Dmitry Mozolev --- contrib/nnc/libs/frontend/tflite/CMakeLists.txt | 1 + .../frontend/tflite/include/tflite_op_creator.h | 79 ++++++++++++++++++++++ .../libs/frontend/tflite/src/tflite_op_creator.cpp | 57 ++++++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h create mode 100644 contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp diff --git a/contrib/nnc/libs/frontend/tflite/CMakeLists.txt b/contrib/nnc/libs/frontend/tflite/CMakeLists.txt index b8d0e55..c4c6e40 100644 --- a/contrib/nnc/libs/frontend/tflite/CMakeLists.txt +++ b/contrib/nnc/libs/frontend/tflite/CMakeLists.txt @@ -15,6 +15,7 @@ FlatBuffers_Generate(FB_GEN set(tflite_importer_sources ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_walker.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_dump_visitor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_ir_visitor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_op_creator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_v3_importer.cpp) file(GLOB tflite_importer_headers include/*.h) list(APPEND tflite_importer_headers ${FB_GEN_SOURCES}) diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h b/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h new file mode 100644 index 0000000..6b3e3a7 --- /dev/null +++ b/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h @@ -0,0 +1,79 @@ +#ifndef NNCC_TFLITE_OP_CREATOR_H +#define NNCC_TFLITE_OP_CREATOR_H + +#include +#include +#include +#include +#include + +#include "PluginException.h" +#include "nnc/core/IR/model/graph/graph.h" +#include "nnc/core/IR/model/graph/ir_node.h" +#include "nnc/core/linalg/TensorVariant.h" +#include "nncc/core/ADT/tensor/Shape.h" + +#include "schema_v3.h" +#include "shape_helper.h" + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +using namespace v3_tflite; + +namespace ops = nncc::contrib::core::IR::model::ops; +using nncc::contrib::core::IR::model::Graph; +using nncc::contrib::core::IR::model::ADT::INode; +using IrTensor = nncc::contrib::core::ADT::TensorVariant; +using nncc::core::ADT::tensor::Shape; + +class OpCreator +{ +public: + using InputOps = std::vector &; + using InputParams = std::vector> &; + + explicit OpCreator(Graph *g) : graph(g) {}; + + std::vector createConv2D(InputOps inputs, InputParams params, + const Conv2DOptions *opts); + std::vector createDepthConv2D(InputOps inputs, InputParams params, + const DepthwiseConv2DOptions *opts); + std::vector createConcat(InputOps inputs, InputParams params, + const ConcatenationOptions *opts); + std::vector createMaxPool(InputOps inputs, InputParams params, + const Pool2DOptions *opts); + std::vector createAvgPool(InputOps inputs, InputParams params, + const Pool2DOptions *opts); + std::vector createSoftmax(InputOps inputs, InputParams params, + const SoftmaxOptions *opts); + std::vector createReshape(InputOps inputs, InputParams params, + const ReshapeOptions *opts); + +private: + Graph *graph = nullptr; + + template + std::vector createOp(std::vector &inputs, + ActivationFunctionType activation, Types &&... args); +}; + +template +std::vector OpCreator::createOp(std::vector &inputs, + ActivationFunctionType activation, Types &&... args) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc + +#endif // NNCC_TFLITE_OP_CREATOR_H diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp new file mode 100644 index 0000000..31dad85 --- /dev/null +++ b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp @@ -0,0 +1,57 @@ +#include "tflite_op_creator.h" + +namespace nncc +{ +namespace contrib +{ +namespace frontend +{ +namespace tflite +{ + +std::vector OpCreator::createConv2D(InputOps inputs, InputParams params, + const Conv2DOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +std::vector OpCreator::createDepthConv2D(InputOps inputs, InputParams params, + const DepthwiseConv2DOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +std::vector OpCreator::createConcat(InputOps inputs, InputParams params, + const ConcatenationOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +std::vector OpCreator::createMaxPool(InputOps inputs, InputParams params, + const Pool2DOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +std::vector OpCreator::createAvgPool(InputOps inputs, InputParams params, + const Pool2DOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +std::vector OpCreator::createSoftmax(InputOps inputs, InputParams params, + const SoftmaxOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +std::vector OpCreator::createReshape(InputOps inputs, InputParams params, + const ReshapeOptions *opts) +{ + throw std::runtime_error{"Not yet implemented"}; +} + +} // namespace tflite +} // namespace frontend +} // namespace contrib +} // namespace nncc -- 2.7.4