Add Model IR operator creator skeleton for TFLite (#354)
authorDmitry Mozolev/AI Tools Lab /SRR/Engineer/삼성전자 <d.mozolev@samsung.com>
Wed, 27 Jun 2018 07:19:44 +0000 (10:19 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Wed, 27 Jun 2018 07:19:44 +0000 (16:19 +0900)
Add Model IR operator creator skeleton for TFLite

This is a Model IR operator creator class skeleton.
Contains public interface for creating supported operations.

Signed-off-by: Dmitry Mozolev <d.mozolev@samsung.com>
contrib/nnc/libs/frontend/tflite/CMakeLists.txt
contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h [new file with mode: 0644]
contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp [new file with mode: 0644]

index b8d0e55..c4c6e40 100644 (file)
@@ -15,6 +15,7 @@ FlatBuffers_Generate(FB_GEN
 set(tflite_importer_sources ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_walker.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_dump_visitor.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_ir_visitor.cpp
+                            ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_op_creator.cpp
                             ${CMAKE_CURRENT_SOURCE_DIR}/src/tflite_v3_importer.cpp)
 file(GLOB tflite_importer_headers include/*.h)
 list(APPEND tflite_importer_headers ${FB_GEN_SOURCES})
diff --git a/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h b/contrib/nnc/libs/frontend/tflite/include/tflite_op_creator.h
new file mode 100644 (file)
index 0000000..6b3e3a7
--- /dev/null
@@ -0,0 +1,79 @@
+#ifndef NNCC_TFLITE_OP_CREATOR_H
+#define NNCC_TFLITE_OP_CREATOR_H
+
+#include <map>
+#include <vector>
+#include <memory>
+#include <cstdint>
+#include <stdexcept>
+
+#include "PluginException.h"
+#include "nnc/core/IR/model/graph/graph.h"
+#include "nnc/core/IR/model/graph/ir_node.h"
+#include "nnc/core/linalg/TensorVariant.h"
+#include "nncc/core/ADT/tensor/Shape.h"
+
+#include "schema_v3.h"
+#include "shape_helper.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace frontend
+{
+namespace tflite
+{
+
+using namespace v3_tflite;
+
+namespace ops = nncc::contrib::core::IR::model::ops;
+using nncc::contrib::core::IR::model::Graph;
+using nncc::contrib::core::IR::model::ADT::INode;
+using IrTensor = nncc::contrib::core::ADT::TensorVariant;
+using nncc::core::ADT::tensor::Shape;
+
+class OpCreator
+{
+public:
+  using InputOps = std::vector<INode::Ref> &;
+  using InputParams = std::vector<std::shared_ptr<IrTensor>> &;
+
+  explicit OpCreator(Graph *g) : graph(g) {};
+
+  std::vector<INode::Ref> createConv2D(InputOps inputs, InputParams params,
+                                       const Conv2DOptions *opts);
+  std::vector<INode::Ref> createDepthConv2D(InputOps inputs, InputParams params,
+                                            const DepthwiseConv2DOptions *opts);
+  std::vector<INode::Ref> createConcat(InputOps inputs, InputParams params,
+                                       const ConcatenationOptions *opts);
+  std::vector<INode::Ref> createMaxPool(InputOps inputs, InputParams params,
+                                        const Pool2DOptions *opts);
+  std::vector<INode::Ref> createAvgPool(InputOps inputs, InputParams params,
+                                        const Pool2DOptions *opts);
+  std::vector<INode::Ref> createSoftmax(InputOps inputs, InputParams params,
+                                        const SoftmaxOptions *opts);
+  std::vector<INode::Ref> createReshape(InputOps inputs, InputParams params,
+                                        const ReshapeOptions *opts);
+
+private:
+  Graph *graph = nullptr;
+
+  template <typename OpType, typename... Types>
+  std::vector<INode::Ref> createOp(std::vector<INode::Ref> &inputs,
+                                   ActivationFunctionType activation, Types &&... args);
+};
+
+template <typename OpType, typename... Types>
+std::vector<INode::Ref> OpCreator::createOp(std::vector<INode::Ref> &inputs,
+                                            ActivationFunctionType activation, Types &&... args)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+} // namespace tflite
+} // namespace frontend
+} // namespace contrib
+} // namespace nncc
+
+#endif // NNCC_TFLITE_OP_CREATOR_H
diff --git a/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp b/contrib/nnc/libs/frontend/tflite/src/tflite_op_creator.cpp
new file mode 100644 (file)
index 0000000..31dad85
--- /dev/null
@@ -0,0 +1,57 @@
+#include "tflite_op_creator.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace frontend
+{
+namespace tflite
+{
+
+std::vector<INode::Ref> OpCreator::createConv2D(InputOps inputs, InputParams params,
+                                                const Conv2DOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+std::vector<INode::Ref> OpCreator::createDepthConv2D(InputOps inputs, InputParams params,
+                                                     const DepthwiseConv2DOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+std::vector<INode::Ref> OpCreator::createConcat(InputOps inputs, InputParams params,
+                                                const ConcatenationOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+std::vector<INode::Ref> OpCreator::createMaxPool(InputOps inputs, InputParams params,
+                                                 const Pool2DOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+std::vector<INode::Ref> OpCreator::createAvgPool(InputOps inputs, InputParams params,
+                                                 const Pool2DOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+std::vector<INode::Ref> OpCreator::createSoftmax(InputOps inputs, InputParams params,
+                                                 const SoftmaxOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+std::vector<INode::Ref> OpCreator::createReshape(InputOps inputs, InputParams params,
+                                                 const ReshapeOptions *opts)
+{
+  throw std::runtime_error{"Not yet implemented"};
+}
+
+} // namespace tflite
+} // namespace frontend
+} // namespace contrib
+} // namespace nncc