[neurun] Templatize BaseLoader (#8012)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Thu, 10 Oct 2019 01:34:50 +0000 (04:34 +0300)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 10 Oct 2019 01:34:50 +0000 (10:34 +0900)
* Templatize `BaseLoader` class.
* Add tflite-specific `LoaderDomain` structure.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
runtimes/neurun/frontend/base_loader/base_loader.h
runtimes/neurun/frontend/tflite/tflite_loader.cc

index 95648e3..2eee5ba 100644 (file)
@@ -29,8 +29,20 @@ namespace neurun
 namespace base_loader
 {
 
-class BaseLoader
-{
+template <typename LoaderDomain, typename SpecificLoader> class BaseLoader
+{
+  using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType;
+  using Buffer = typename LoaderDomain::Buffer;
+  using BuiltinOperator = typename LoaderDomain::BuiltinOperator;
+  using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat;
+  using Model = typename LoaderDomain::Model;
+  using Operator = typename LoaderDomain::Operator;
+  using Padding = typename LoaderDomain::Padding;
+  using Pool2DOptions = typename LoaderDomain::Pool2DOptions;
+  using SubGraph = typename LoaderDomain::SubGraph;
+  using Tensor = typename LoaderDomain::Tensor;
+  using TensorType = typename LoaderDomain::TensorType;
+
 public:
   /**
    * @brief Construct a new Loader object
@@ -49,67 +61,67 @@ public:
 protected:
   ~BaseLoader() = default;
 
-private:
   void loadModel();
 
   // Helper functions
-  model::Activation convertActivation(tflite::ActivationFunctionType type);
-  model::DataType tensorTypeToDataType(tflite::TensorType type);
+  model::Activation convertActivation(ActivationFunctionType type);
+  model::DataType tensorTypeToDataType(TensorType type);
 
   // Load subgraphs
-  void loadSubgraph(const tflite::SubGraph *subgraph);
+  void loadSubgraph(const SubGraph *subgraph);
   // Load data from buffer to tensor on index
-  void loadConstantTensor(const tflite::Buffer *buffer, const uint32_t &index);
+  void loadConstantTensor(const Buffer *buffer, const uint32_t &index);
   // Create operands form tflite::Tensor
-  void loadOperand(const tflite::Tensor *tensor);
-  void loadOperationIO(const tflite::Operator *op, model::OperandIndexSequence &inputs,
+  void loadOperand(const Tensor *tensor);
+  void loadOperationIO(const Operator *op, model::OperandIndexSequence &inputs,
                        model::OperandIndexSequence &outputs);
-  // Create operations from tflite::Operator
-  void loadOperation(const tflite::Operator *op);
+  // Create operations from Operator
+  void loadOperation(const Operator *op);
   // Load Strides and Paddings from options to param
   template <typename Param, typename OptionsType>
   void loadStridesAndPaddings(Param &param, const OptionsType *options);
   // Load Pool2D param
-  template <typename Param> void loadPool2D(Param &param, const tflite::Pool2DOptions *options);
+  template <typename Param> void loadPool2D(Param &param, const Pool2DOptions *options);
 
   // Operations
-  void loadConv2D(const tflite::Operator *op);
-  void loadDepthwiseConv2D(const tflite::Operator *op);
-  void loadTransposeConv(const tflite::Operator *op);
-  void loadAvgPool2D(const tflite::Operator *op);
-  void loadReshape(const tflite::Operator *op);
-  void loadSoftmax(const tflite::Operator *op);
-  void loadMaxPool2D(const tflite::Operator *op);
-  void loadConcatenation(const tflite::Operator *op);
-  void loadFC(const tflite::Operator *op);
-  void loadAdd(const tflite::Operator *op);
-  void loadSub(const tflite::Operator *op);
-  void loadMul(const tflite::Operator *op);
-  void loadDiv(const tflite::Operator *op);
-  void loadRelu(const tflite::Operator *op);
-  void loadRelu6(const tflite::Operator *op);
-  void loadRsqrt(const tflite::Operator *op);
-  void loadSqrt(const tflite::Operator *op);
-  void loadSquaredDifference(const tflite::Operator *op);
-  void loadTanh(const tflite::Operator *op);
-  void loadTranspose(const tflite::Operator *op);
-  void loadMean(const tflite::Operator *op);
-  void loadPad(const tflite::Operator *op);
-  void loadCustom(const tflite::Operator *op);
-
-private:
+  void loadConv2D(const Operator *op);
+  void loadDepthwiseConv2D(const Operator *op);
+  void loadTransposeConv(const Operator *op);
+  void loadAvgPool2D(const Operator *op);
+  void loadReshape(const Operator *op);
+  void loadSoftmax(const Operator *op);
+  void loadMaxPool2D(const Operator *op);
+  void loadConcatenation(const Operator *op);
+  void loadFC(const Operator *op);
+  void loadAdd(const Operator *op);
+  void loadSub(const Operator *op);
+  void loadMul(const Operator *op);
+  void loadDiv(const Operator *op);
+  void loadRelu(const Operator *op);
+  void loadRelu6(const Operator *op);
+  void loadRsqrt(const Operator *op);
+  void loadSqrt(const Operator *op);
+  void loadSquaredDifference(const Operator *op);
+  void loadTanh(const Operator *op);
+  void loadTranspose(const Operator *op);
+  void loadMean(const Operator *op);
+  void loadPad(const Operator *op);
+  void loadCustom(const Operator *op);
+
+protected:
   // Buffer for loading (if needed)
   std::vector<char> _buffer;
   // Reference on loadable Graph
   graph::Graph &_graph;
-  // Mapping from tflite tensor index to Graph OperandIndex
+  // Mapping from tensor index to Graph OperandIndex
   std::map<uint32_t, model::OperandIndex> _tensor_to_operand;
   // Mapping from operator code to BuiltinOperator
-  std::vector<tflite::BuiltinOperator> _op_code_to_builtin_op;
+  std::vector<BuiltinOperator> _op_code_to_builtin_op;
   std::unordered_map<uint32_t, std::string> _opcode_index_to_custom_opcode;
 };
 
-void BaseLoader::loadFromFile(const char *file_path)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::loadFromFile(const char *file_path)
 {
   std::ifstream stream(file_path, std::fstream::ate | std::fstream::binary);
   auto size = stream.tellg();
@@ -122,45 +134,50 @@ void BaseLoader::loadFromFile(const char *file_path)
   loadModel();
 }
 
-model::Activation BaseLoader::convertActivation(const tflite::ActivationFunctionType type)
+template <typename LoaderDomain, typename SpecificLoader>
+model::Activation BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::convertActivation(
+    const ActivationFunctionType type)
 {
   switch (type)
   {
-    case tflite::ActivationFunctionType::ActivationFunctionType_NONE:
+    case ActivationFunctionType::ActivationFunctionType_NONE:
       return model::Activation::NONE;
-    case tflite::ActivationFunctionType::ActivationFunctionType_RELU:
+    case ActivationFunctionType::ActivationFunctionType_RELU:
       return model::Activation::RELU;
-    case tflite::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
+    case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
       return model::Activation::RELU1;
-    case tflite::ActivationFunctionType::ActivationFunctionType_RELU6:
+    case ActivationFunctionType::ActivationFunctionType_RELU6:
       return model::Activation::RELU6;
-    case tflite::ActivationFunctionType::ActivationFunctionType_TANH:
+    case ActivationFunctionType::ActivationFunctionType_TANH:
       return model::Activation::TANH;
     default:
       throw std::runtime_error(std::string("Unsupported activation type: ")
-                                   .append(tflite::EnumNameActivationFunctionType(type)));
+                                   .append(EnumNameActivationFunctionType(type)));
   }
 }
 
-model::DataType BaseLoader::tensorTypeToDataType(const tflite::TensorType type)
+template <typename LoaderDomain, typename SpecificLoader>
+model::DataType
+BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::tensorTypeToDataType(const TensorType type)
 {
   switch (type)
   {
-    case tflite::TensorType::TensorType_FLOAT32:
+    case TensorType::TensorType_FLOAT32:
       return model::DataType::FLOAT32;
-    case tflite::TensorType::TensorType_INT32:
+    case TensorType::TensorType_INT32:
       return model::DataType::INT32;
-    case tflite::TensorType::TensorType_BOOL:
+    case TensorType::TensorType_BOOL:
       return model::DataType::BOOL8;
-    case tflite::TensorType::TensorType_INT8:
+    case TensorType::TensorType_INT8:
       return model::DataType::QUANT8_ASYMM;
     default:
       throw std::runtime_error(
-          std::string("Unsupported tensor type: ").append(tflite::EnumNameTensorType(type)));
+          std::string("Unsupported tensor type: ").append(EnumNameTensorType(type)));
   }
 }
 
-void BaseLoader::loadOperand(const tflite::Tensor *tensor)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadOperand(const Tensor *tensor)
 {
   model::Shape shape;
   std::unique_ptr<model::Data> data_ptr;
@@ -199,8 +216,10 @@ void BaseLoader::loadOperand(const tflite::Tensor *tensor)
     throw std::runtime_error("Variable tensor not supported!");
 }
 
-void BaseLoader::loadOperationIO(const tflite::Operator *op, model::OperandIndexSequence &inputs,
-                                 model::OperandIndexSequence &outputs)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadOperationIO(const Operator *op,
+                                                               model::OperandIndexSequence &inputs,
+                                                               model::OperandIndexSequence &outputs)
 {
   for (const auto &idx : *op->inputs())
   {
@@ -213,8 +232,10 @@ void BaseLoader::loadOperationIO(const tflite::Operator *op, model::OperandIndex
   }
 }
 
+template <typename LoaderDomain, typename SpecificLoader>
 template <typename Param, typename OptionsType>
-void BaseLoader::loadStridesAndPaddings(Param &param, const OptionsType *options)
+void BaseLoader<LoaderDomain, SpecificLoader>::loadStridesAndPaddings(Param &param,
+                                                                      const OptionsType *options)
 {
   model::Shape shape;
   model::TypeInfo type_info(model::DataType::INT32);
@@ -222,15 +243,17 @@ void BaseLoader::loadStridesAndPaddings(Param &param, const OptionsType *options
   param.stride.vertical = options->stride_w();
   param.stride.horizontal = options->stride_h();
   // Paddings
-  if (options->padding() == tflite::Padding::Padding_SAME)
+  if (options->padding() == Padding::Padding_SAME)
     param.padding.type = model::PaddingType::SAME;
-  if (options->padding() == tflite::Padding::Padding_VALID)
+  if (options->padding() == Padding::Padding_VALID)
     param.padding.type = model::PaddingType::VALID;
   // param paddings indexes unused
 }
 
+template <typename LoaderDomain, typename SpecificLoader>
 template <typename Param>
-void BaseLoader::loadPool2D(Param &param, const tflite::Pool2DOptions *options)
+void BaseLoader<LoaderDomain, SpecificLoader>::loadPool2D(Param &param,
+                                                          const Pool2DOptions *options)
 {
   // Strides and Paddings
   loadStridesAndPaddings(param, options);
@@ -244,7 +267,8 @@ void BaseLoader::loadPool2D(Param &param, const tflite::Pool2DOptions *options)
   param.activation = convertActivation(options->fused_activation_function());
 }
 
-void BaseLoader::loadConv2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadConv2D(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -261,7 +285,8 @@ void BaseLoader::loadConv2D(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadDepthwiseConv2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadDepthwiseConv2D(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -282,7 +307,8 @@ void BaseLoader::loadDepthwiseConv2D(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadTransposeConv(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadTransposeConv(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -297,7 +323,8 @@ void BaseLoader::loadTransposeConv(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadAvgPool2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadAvgPool2D(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -314,7 +341,8 @@ void BaseLoader::loadAvgPool2D(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadReshape(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadReshape(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -328,7 +356,8 @@ void BaseLoader::loadReshape(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadSoftmax(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSoftmax(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -345,7 +374,8 @@ void BaseLoader::loadSoftmax(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadMaxPool2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadMaxPool2D(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -362,7 +392,8 @@ void BaseLoader::loadMaxPool2D(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadConcatenation(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadConcatenation(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -382,7 +413,8 @@ void BaseLoader::loadConcatenation(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadFC(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadFC(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -400,7 +432,8 @@ void BaseLoader::loadFC(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadAdd(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadAdd(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -416,7 +449,8 @@ void BaseLoader::loadAdd(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadSub(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSub(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -432,7 +466,8 @@ void BaseLoader::loadSub(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadMul(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadMul(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -448,7 +483,8 @@ void BaseLoader::loadMul(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadDiv(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadDiv(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -464,7 +500,8 @@ void BaseLoader::loadDiv(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadRelu(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadRelu(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -475,7 +512,8 @@ void BaseLoader::loadRelu(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadRelu6(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadRelu6(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -486,7 +524,8 @@ void BaseLoader::loadRelu6(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadRsqrt(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadRsqrt(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -497,7 +536,8 @@ void BaseLoader::loadRsqrt(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadSqrt(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSqrt(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -508,7 +548,8 @@ void BaseLoader::loadSqrt(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadSquaredDifference(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSquaredDifference(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -520,7 +561,8 @@ void BaseLoader::loadSquaredDifference(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadTanh(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadTanh(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -531,7 +573,8 @@ void BaseLoader::loadTanh(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadTranspose(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadTranspose(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -553,7 +596,8 @@ void BaseLoader::loadTranspose(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadMean(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadMean(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -571,7 +615,8 @@ void BaseLoader::loadMean(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadPad(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadPad(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -582,7 +627,8 @@ void BaseLoader::loadPad(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadCustom(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadCustom(const Operator *op)
 {
   model::OperandIndexSequence inputs;
   model::OperandIndexSequence outputs;
@@ -601,7 +647,7 @@ void BaseLoader::loadCustom(const tflite::Operator *op)
 
   auto constraint = model::operation::OperandConstraint::createExact(inputs.size());
 
-  assert(op->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS &&
+  assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS &&
          "Unsupported custom operation options format");
 
   size_t custom_op_data_size = op->custom_options()->size();
@@ -618,88 +664,90 @@ void BaseLoader::loadCustom(const tflite::Operator *op)
   _graph.addOperation(std::move(new_op));
 }
 
-void BaseLoader::loadOperation(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadOperation(const Operator *op)
 {
   const auto builtin_op = _op_code_to_builtin_op[op->opcode_index()];
 
   switch (builtin_op)
   {
-    case tflite::BuiltinOperator_CONV_2D:
+    case BuiltinOperator::BuiltinOperator_CONV_2D:
       loadConv2D(op);
       return;
-    case tflite::BuiltinOperator_AVERAGE_POOL_2D:
+    case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D:
       loadAvgPool2D(op);
       return;
-    case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
+    case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
       loadDepthwiseConv2D(op);
       return;
-    case tflite::BuiltinOperator_TRANSPOSE_CONV:
+    case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV:
       loadTransposeConv(op);
       return;
-    case tflite::BuiltinOperator_RESHAPE:
+    case BuiltinOperator::BuiltinOperator_RESHAPE:
       loadReshape(op);
       return;
-    case tflite::BuiltinOperator_SOFTMAX:
+    case BuiltinOperator::BuiltinOperator_SOFTMAX:
       loadSoftmax(op);
       return;
-    case tflite::BuiltinOperator_MAX_POOL_2D:
+    case BuiltinOperator::BuiltinOperator_MAX_POOL_2D:
       loadMaxPool2D(op);
       return;
-    case tflite::BuiltinOperator_CONCATENATION:
+    case BuiltinOperator::BuiltinOperator_CONCATENATION:
       loadConcatenation(op);
       return;
-    case tflite::BuiltinOperator_FULLY_CONNECTED:
+    case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
       loadFC(op);
       return;
-    case tflite::BuiltinOperator_ADD:
+    case BuiltinOperator::BuiltinOperator_ADD:
       loadAdd(op);
       return;
-    case tflite::BuiltinOperator_SUB:
+    case BuiltinOperator::BuiltinOperator_SUB:
       loadSub(op);
       return;
-    case tflite::BuiltinOperator_MUL:
+    case BuiltinOperator::BuiltinOperator_MUL:
       loadMul(op);
       return;
-    case tflite::BuiltinOperator_DIV:
+    case BuiltinOperator::BuiltinOperator_DIV:
       loadDiv(op);
       return;
-    case tflite::BuiltinOperator_RELU:
+    case BuiltinOperator::BuiltinOperator_RELU:
       loadRelu(op);
       return;
-    case tflite::BuiltinOperator_RELU6:
+    case BuiltinOperator::BuiltinOperator_RELU6:
       loadRelu6(op);
       return;
-    case tflite::BuiltinOperator_RSQRT:
+    case BuiltinOperator::BuiltinOperator_RSQRT:
       loadRsqrt(op);
       return;
-    case tflite::BuiltinOperator_SQRT:
+    case BuiltinOperator::BuiltinOperator_SQRT:
       loadSqrt(op);
       return;
-    case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
+    case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE:
       loadSquaredDifference(op);
       return;
-    case tflite::BuiltinOperator_TANH:
+    case BuiltinOperator::BuiltinOperator_TANH:
       loadTanh(op);
       return;
-    case tflite::BuiltinOperator_TRANSPOSE:
+    case BuiltinOperator::BuiltinOperator_TRANSPOSE:
       loadTranspose(op);
       return;
-    case tflite::BuiltinOperator_MEAN:
+    case BuiltinOperator::BuiltinOperator_MEAN:
       loadMean(op);
       return;
-    case tflite::BuiltinOperator_PAD:
+    case BuiltinOperator::BuiltinOperator_PAD:
       loadPad(op);
       return;
-    case tflite::BuiltinOperator_CUSTOM:
+    case BuiltinOperator::BuiltinOperator_CUSTOM:
       loadCustom(op);
       return;
     default:
-      throw std::runtime_error(std::string("Unsupported operation: ")
-                                   .append(tflite::EnumNameBuiltinOperator(builtin_op)));
+      throw std::runtime_error(
+          std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
   }
 }
 
-void BaseLoader::loadSubgraph(const tflite::SubGraph *subgraph)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSubgraph(const SubGraph *subgraph)
 {
   // Load tensors
   for (const auto *tensor : *subgraph->tensors())
@@ -724,7 +772,9 @@ void BaseLoader::loadSubgraph(const tflite::SubGraph *subgraph)
   // Name unused
 }
 
-void BaseLoader::loadConstantTensor(const tflite::Buffer *buffer, const uint32_t &index)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadConstantTensor(const Buffer *buffer,
+                                                                  const uint32_t &index)
 {
   const auto *data = buffer->data();
   if (data != nullptr)
@@ -736,9 +786,10 @@ void BaseLoader::loadConstantTensor(const tflite::Buffer *buffer, const uint32_t
   }
 }
 
-void BaseLoader::loadModel()
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadModel()
 {
-  const auto *model = tflite::GetModel(_buffer.data());
+  const auto *model = LoaderDomain::GetModel(_buffer.data());
   // Version unused
   // const auto version = model->version();
   const auto *op_codes = model->operator_codes();
@@ -753,7 +804,7 @@ void BaseLoader::loadModel()
   {
     _op_code_to_builtin_op.push_back(op_code->builtin_code());
 
-    if (op_code->builtin_code() == tflite::BuiltinOperator_CUSTOM)
+    if (op_code->builtin_code() == BuiltinOperator::BuiltinOperator_CUSTOM)
     {
       auto id = op_code->custom_code()->str();
       _opcode_index_to_custom_opcode[_op_code_to_builtin_op.size() - 1] = id;
index c8b2bc5..26566d7 100644 (file)
@@ -15,8 +15,8 @@
  */
 
 #include "tflite_loader.h"
-#include "schema_generated.h"
 #include "base_loader.h"
+#include "schema_generated.h"
 
 namespace neurun
 {
@@ -26,7 +26,33 @@ namespace tflite_loader
 namespace
 {
 
-class TFLiteLoader : private base_loader::BaseLoader
+struct LoaderDomain
+{
+  using ActivationFunctionType = tflite::ActivationFunctionType;
+  using Buffer = tflite::Buffer;
+  using BuiltinOperator = tflite::BuiltinOperator;
+  using CustomOptionsFormat = tflite::CustomOptionsFormat;
+  using Model = tflite::Model;
+  using Operator = tflite::Operator;
+  using Padding = tflite::Padding;
+  using Pool2DOptions = tflite::Pool2DOptions;
+  using Tensor = tflite::Tensor;
+  using TensorType = tflite::TensorType;
+  using SubGraph = tflite::SubGraph;
+
+  static const char *EnumNameBuiltinOperator(BuiltinOperator e)
+  {
+    return tflite::EnumNameBuiltinOperator(e);
+  }
+  static const char *EnumNameActivationFunctionType(ActivationFunctionType e)
+  {
+    return tflite::EnumNameActivationFunctionType(e);
+  }
+  static const char *EnumNameTensorType(TensorType e) { return tflite::EnumNameTensorType(e); }
+  static const Model *GetModel(const void *buf) { return tflite::GetModel(buf); }
+};
+
+class TFLiteLoader final : private base_loader::BaseLoader<LoaderDomain, TFLiteLoader>
 {
 public:
   using BaseLoader::BaseLoader;