From f93671ef9e458a3fe81185732074dc2eeccb07b7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=A1=D0=B5=D1=80=D0=B3=D0=B5=D0=B9=20=D0=91=D0=B0=D1=80?= =?utf8?q?=D0=B0=D0=BD=D0=BD=D0=B8=D0=BA=D0=BE=D0=B2/AI=20Tools=20Lab=20/S?= =?utf8?q?RR/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 10 Oct 2019 04:34:50 +0300 Subject: [PATCH] [neurun] Templatize BaseLoader (#8012) * Templatize `BaseLoader` class. * Add tflite-specific `LoaderDomain` structure. Signed-off-by: Sergei Barannikov --- runtimes/neurun/frontend/base_loader/base_loader.h | 281 ++++++++++++--------- runtimes/neurun/frontend/tflite/tflite_loader.cc | 30 ++- 2 files changed, 194 insertions(+), 117 deletions(-) diff --git a/runtimes/neurun/frontend/base_loader/base_loader.h b/runtimes/neurun/frontend/base_loader/base_loader.h index 95648e3..2eee5ba 100644 --- a/runtimes/neurun/frontend/base_loader/base_loader.h +++ b/runtimes/neurun/frontend/base_loader/base_loader.h @@ -29,8 +29,20 @@ namespace neurun namespace base_loader { -class BaseLoader -{ +template class BaseLoader +{ + using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType; + using Buffer = typename LoaderDomain::Buffer; + using BuiltinOperator = typename LoaderDomain::BuiltinOperator; + using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat; + using Model = typename LoaderDomain::Model; + using Operator = typename LoaderDomain::Operator; + using Padding = typename LoaderDomain::Padding; + using Pool2DOptions = typename LoaderDomain::Pool2DOptions; + using SubGraph = typename LoaderDomain::SubGraph; + using Tensor = typename LoaderDomain::Tensor; + using TensorType = typename LoaderDomain::TensorType; + public: /** * @brief Construct a new Loader object @@ -49,67 +61,67 @@ public: protected: ~BaseLoader() = default; -private: void loadModel(); // Helper functions - model::Activation convertActivation(tflite::ActivationFunctionType type); - model::DataType tensorTypeToDataType(tflite::TensorType type); + model::Activation convertActivation(ActivationFunctionType type); + model::DataType tensorTypeToDataType(TensorType type); // Load subgraphs - void loadSubgraph(const tflite::SubGraph *subgraph); + void loadSubgraph(const SubGraph *subgraph); // Load data from buffer to tensor on index - void loadConstantTensor(const tflite::Buffer *buffer, const uint32_t &index); + void loadConstantTensor(const Buffer *buffer, const uint32_t &index); // Create operands form tflite::Tensor - void loadOperand(const tflite::Tensor *tensor); - void loadOperationIO(const tflite::Operator *op, model::OperandIndexSequence &inputs, + void loadOperand(const Tensor *tensor); + void loadOperationIO(const Operator *op, model::OperandIndexSequence &inputs, model::OperandIndexSequence &outputs); - // Create operations from tflite::Operator - void loadOperation(const tflite::Operator *op); + // Create operations from Operator + void loadOperation(const Operator *op); // Load Strides and Paddings from options to param template void loadStridesAndPaddings(Param ¶m, const OptionsType *options); // Load Pool2D param - template void loadPool2D(Param ¶m, const tflite::Pool2DOptions *options); + template void loadPool2D(Param ¶m, const Pool2DOptions *options); // Operations - void loadConv2D(const tflite::Operator *op); - void loadDepthwiseConv2D(const tflite::Operator *op); - void loadTransposeConv(const tflite::Operator *op); - void loadAvgPool2D(const tflite::Operator *op); - void loadReshape(const tflite::Operator *op); - void loadSoftmax(const tflite::Operator *op); - void loadMaxPool2D(const tflite::Operator *op); - void loadConcatenation(const tflite::Operator *op); - void loadFC(const tflite::Operator *op); - void loadAdd(const tflite::Operator *op); - void loadSub(const tflite::Operator *op); - void loadMul(const tflite::Operator *op); - void loadDiv(const tflite::Operator *op); - void loadRelu(const tflite::Operator *op); - void loadRelu6(const tflite::Operator *op); - void loadRsqrt(const tflite::Operator *op); - void loadSqrt(const tflite::Operator *op); - void loadSquaredDifference(const tflite::Operator *op); - void loadTanh(const tflite::Operator *op); - void loadTranspose(const tflite::Operator *op); - void loadMean(const tflite::Operator *op); - void loadPad(const tflite::Operator *op); - void loadCustom(const tflite::Operator *op); - -private: + void loadConv2D(const Operator *op); + void loadDepthwiseConv2D(const Operator *op); + void loadTransposeConv(const Operator *op); + void loadAvgPool2D(const Operator *op); + void loadReshape(const Operator *op); + void loadSoftmax(const Operator *op); + void loadMaxPool2D(const Operator *op); + void loadConcatenation(const Operator *op); + void loadFC(const Operator *op); + void loadAdd(const Operator *op); + void loadSub(const Operator *op); + void loadMul(const Operator *op); + void loadDiv(const Operator *op); + void loadRelu(const Operator *op); + void loadRelu6(const Operator *op); + void loadRsqrt(const Operator *op); + void loadSqrt(const Operator *op); + void loadSquaredDifference(const Operator *op); + void loadTanh(const Operator *op); + void loadTranspose(const Operator *op); + void loadMean(const Operator *op); + void loadPad(const Operator *op); + void loadCustom(const Operator *op); + +protected: // Buffer for loading (if needed) std::vector _buffer; // Reference on loadable Graph graph::Graph &_graph; - // Mapping from tflite tensor index to Graph OperandIndex + // Mapping from tensor index to Graph OperandIndex std::map _tensor_to_operand; // Mapping from operator code to BuiltinOperator - std::vector _op_code_to_builtin_op; + std::vector _op_code_to_builtin_op; std::unordered_map _opcode_index_to_custom_opcode; }; -void BaseLoader::loadFromFile(const char *file_path) +template +void BaseLoader::BaseLoader::loadFromFile(const char *file_path) { std::ifstream stream(file_path, std::fstream::ate | std::fstream::binary); auto size = stream.tellg(); @@ -122,45 +134,50 @@ void BaseLoader::loadFromFile(const char *file_path) loadModel(); } -model::Activation BaseLoader::convertActivation(const tflite::ActivationFunctionType type) +template +model::Activation BaseLoader::BaseLoader::convertActivation( + const ActivationFunctionType type) { switch (type) { - case tflite::ActivationFunctionType::ActivationFunctionType_NONE: + case ActivationFunctionType::ActivationFunctionType_NONE: return model::Activation::NONE; - case tflite::ActivationFunctionType::ActivationFunctionType_RELU: + case ActivationFunctionType::ActivationFunctionType_RELU: return model::Activation::RELU; - case tflite::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1: + case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1: return model::Activation::RELU1; - case tflite::ActivationFunctionType::ActivationFunctionType_RELU6: + case ActivationFunctionType::ActivationFunctionType_RELU6: return model::Activation::RELU6; - case tflite::ActivationFunctionType::ActivationFunctionType_TANH: + case ActivationFunctionType::ActivationFunctionType_TANH: return model::Activation::TANH; default: throw std::runtime_error(std::string("Unsupported activation type: ") - .append(tflite::EnumNameActivationFunctionType(type))); + .append(EnumNameActivationFunctionType(type))); } } -model::DataType BaseLoader::tensorTypeToDataType(const tflite::TensorType type) +template +model::DataType +BaseLoader::BaseLoader::tensorTypeToDataType(const TensorType type) { switch (type) { - case tflite::TensorType::TensorType_FLOAT32: + case TensorType::TensorType_FLOAT32: return model::DataType::FLOAT32; - case tflite::TensorType::TensorType_INT32: + case TensorType::TensorType_INT32: return model::DataType::INT32; - case tflite::TensorType::TensorType_BOOL: + case TensorType::TensorType_BOOL: return model::DataType::BOOL8; - case tflite::TensorType::TensorType_INT8: + case TensorType::TensorType_INT8: return model::DataType::QUANT8_ASYMM; default: throw std::runtime_error( - std::string("Unsupported tensor type: ").append(tflite::EnumNameTensorType(type))); + std::string("Unsupported tensor type: ").append(EnumNameTensorType(type))); } } -void BaseLoader::loadOperand(const tflite::Tensor *tensor) +template +void BaseLoader::loadOperand(const Tensor *tensor) { model::Shape shape; std::unique_ptr data_ptr; @@ -199,8 +216,10 @@ void BaseLoader::loadOperand(const tflite::Tensor *tensor) throw std::runtime_error("Variable tensor not supported!"); } -void BaseLoader::loadOperationIO(const tflite::Operator *op, model::OperandIndexSequence &inputs, - model::OperandIndexSequence &outputs) +template +void BaseLoader::loadOperationIO(const Operator *op, + model::OperandIndexSequence &inputs, + model::OperandIndexSequence &outputs) { for (const auto &idx : *op->inputs()) { @@ -213,8 +232,10 @@ void BaseLoader::loadOperationIO(const tflite::Operator *op, model::OperandIndex } } +template template -void BaseLoader::loadStridesAndPaddings(Param ¶m, const OptionsType *options) +void BaseLoader::loadStridesAndPaddings(Param ¶m, + const OptionsType *options) { model::Shape shape; model::TypeInfo type_info(model::DataType::INT32); @@ -222,15 +243,17 @@ void BaseLoader::loadStridesAndPaddings(Param ¶m, const OptionsType *options param.stride.vertical = options->stride_w(); param.stride.horizontal = options->stride_h(); // Paddings - if (options->padding() == tflite::Padding::Padding_SAME) + if (options->padding() == Padding::Padding_SAME) param.padding.type = model::PaddingType::SAME; - if (options->padding() == tflite::Padding::Padding_VALID) + if (options->padding() == Padding::Padding_VALID) param.padding.type = model::PaddingType::VALID; // param paddings indexes unused } +template template -void BaseLoader::loadPool2D(Param ¶m, const tflite::Pool2DOptions *options) +void BaseLoader::loadPool2D(Param ¶m, + const Pool2DOptions *options) { // Strides and Paddings loadStridesAndPaddings(param, options); @@ -244,7 +267,8 @@ void BaseLoader::loadPool2D(Param ¶m, const tflite::Pool2DOptions *options) param.activation = convertActivation(options->fused_activation_function()); } -void BaseLoader::loadConv2D(const tflite::Operator *op) +template +void BaseLoader::loadConv2D(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -261,7 +285,8 @@ void BaseLoader::loadConv2D(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadDepthwiseConv2D(const tflite::Operator *op) +template +void BaseLoader::loadDepthwiseConv2D(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -282,7 +307,8 @@ void BaseLoader::loadDepthwiseConv2D(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadTransposeConv(const tflite::Operator *op) +template +void BaseLoader::loadTransposeConv(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -297,7 +323,8 @@ void BaseLoader::loadTransposeConv(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadAvgPool2D(const tflite::Operator *op) +template +void BaseLoader::loadAvgPool2D(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -314,7 +341,8 @@ void BaseLoader::loadAvgPool2D(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadReshape(const tflite::Operator *op) +template +void BaseLoader::loadReshape(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -328,7 +356,8 @@ void BaseLoader::loadReshape(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadSoftmax(const tflite::Operator *op) +template +void BaseLoader::loadSoftmax(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -345,7 +374,8 @@ void BaseLoader::loadSoftmax(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadMaxPool2D(const tflite::Operator *op) +template +void BaseLoader::loadMaxPool2D(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -362,7 +392,8 @@ void BaseLoader::loadMaxPool2D(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadConcatenation(const tflite::Operator *op) +template +void BaseLoader::loadConcatenation(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -382,7 +413,8 @@ void BaseLoader::loadConcatenation(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadFC(const tflite::Operator *op) +template +void BaseLoader::loadFC(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -400,7 +432,8 @@ void BaseLoader::loadFC(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadAdd(const tflite::Operator *op) +template +void BaseLoader::loadAdd(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -416,7 +449,8 @@ void BaseLoader::loadAdd(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadSub(const tflite::Operator *op) +template +void BaseLoader::loadSub(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -432,7 +466,8 @@ void BaseLoader::loadSub(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadMul(const tflite::Operator *op) +template +void BaseLoader::loadMul(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -448,7 +483,8 @@ void BaseLoader::loadMul(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadDiv(const tflite::Operator *op) +template +void BaseLoader::loadDiv(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -464,7 +500,8 @@ void BaseLoader::loadDiv(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadRelu(const tflite::Operator *op) +template +void BaseLoader::loadRelu(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -475,7 +512,8 @@ void BaseLoader::loadRelu(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadRelu6(const tflite::Operator *op) +template +void BaseLoader::loadRelu6(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -486,7 +524,8 @@ void BaseLoader::loadRelu6(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadRsqrt(const tflite::Operator *op) +template +void BaseLoader::loadRsqrt(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -497,7 +536,8 @@ void BaseLoader::loadRsqrt(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadSqrt(const tflite::Operator *op) +template +void BaseLoader::loadSqrt(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -508,7 +548,8 @@ void BaseLoader::loadSqrt(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadSquaredDifference(const tflite::Operator *op) +template +void BaseLoader::loadSquaredDifference(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -520,7 +561,8 @@ void BaseLoader::loadSquaredDifference(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadTanh(const tflite::Operator *op) +template +void BaseLoader::loadTanh(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -531,7 +573,8 @@ void BaseLoader::loadTanh(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadTranspose(const tflite::Operator *op) +template +void BaseLoader::loadTranspose(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -553,7 +596,8 @@ void BaseLoader::loadTranspose(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadMean(const tflite::Operator *op) +template +void BaseLoader::loadMean(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -571,7 +615,8 @@ void BaseLoader::loadMean(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadPad(const tflite::Operator *op) +template +void BaseLoader::loadPad(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -582,7 +627,8 @@ void BaseLoader::loadPad(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadCustom(const tflite::Operator *op) +template +void BaseLoader::loadCustom(const Operator *op) { model::OperandIndexSequence inputs; model::OperandIndexSequence outputs; @@ -601,7 +647,7 @@ void BaseLoader::loadCustom(const tflite::Operator *op) auto constraint = model::operation::OperandConstraint::createExact(inputs.size()); - assert(op->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS && + assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS && "Unsupported custom operation options format"); size_t custom_op_data_size = op->custom_options()->size(); @@ -618,88 +664,90 @@ void BaseLoader::loadCustom(const tflite::Operator *op) _graph.addOperation(std::move(new_op)); } -void BaseLoader::loadOperation(const tflite::Operator *op) +template +void BaseLoader::loadOperation(const Operator *op) { const auto builtin_op = _op_code_to_builtin_op[op->opcode_index()]; switch (builtin_op) { - case tflite::BuiltinOperator_CONV_2D: + case BuiltinOperator::BuiltinOperator_CONV_2D: loadConv2D(op); return; - case tflite::BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D: loadAvgPool2D(op); return; - case tflite::BuiltinOperator_DEPTHWISE_CONV_2D: + case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D: loadDepthwiseConv2D(op); return; - case tflite::BuiltinOperator_TRANSPOSE_CONV: + case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV: loadTransposeConv(op); return; - case tflite::BuiltinOperator_RESHAPE: + case BuiltinOperator::BuiltinOperator_RESHAPE: loadReshape(op); return; - case tflite::BuiltinOperator_SOFTMAX: + case BuiltinOperator::BuiltinOperator_SOFTMAX: loadSoftmax(op); return; - case tflite::BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator::BuiltinOperator_MAX_POOL_2D: loadMaxPool2D(op); return; - case tflite::BuiltinOperator_CONCATENATION: + case BuiltinOperator::BuiltinOperator_CONCATENATION: loadConcatenation(op); return; - case tflite::BuiltinOperator_FULLY_CONNECTED: + case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED: loadFC(op); return; - case tflite::BuiltinOperator_ADD: + case BuiltinOperator::BuiltinOperator_ADD: loadAdd(op); return; - case tflite::BuiltinOperator_SUB: + case BuiltinOperator::BuiltinOperator_SUB: loadSub(op); return; - case tflite::BuiltinOperator_MUL: + case BuiltinOperator::BuiltinOperator_MUL: loadMul(op); return; - case tflite::BuiltinOperator_DIV: + case BuiltinOperator::BuiltinOperator_DIV: loadDiv(op); return; - case tflite::BuiltinOperator_RELU: + case BuiltinOperator::BuiltinOperator_RELU: loadRelu(op); return; - case tflite::BuiltinOperator_RELU6: + case BuiltinOperator::BuiltinOperator_RELU6: loadRelu6(op); return; - case tflite::BuiltinOperator_RSQRT: + case BuiltinOperator::BuiltinOperator_RSQRT: loadRsqrt(op); return; - case tflite::BuiltinOperator_SQRT: + case BuiltinOperator::BuiltinOperator_SQRT: loadSqrt(op); return; - case tflite::BuiltinOperator_SQUARED_DIFFERENCE: + case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE: loadSquaredDifference(op); return; - case tflite::BuiltinOperator_TANH: + case BuiltinOperator::BuiltinOperator_TANH: loadTanh(op); return; - case tflite::BuiltinOperator_TRANSPOSE: + case BuiltinOperator::BuiltinOperator_TRANSPOSE: loadTranspose(op); return; - case tflite::BuiltinOperator_MEAN: + case BuiltinOperator::BuiltinOperator_MEAN: loadMean(op); return; - case tflite::BuiltinOperator_PAD: + case BuiltinOperator::BuiltinOperator_PAD: loadPad(op); return; - case tflite::BuiltinOperator_CUSTOM: + case BuiltinOperator::BuiltinOperator_CUSTOM: loadCustom(op); return; default: - throw std::runtime_error(std::string("Unsupported operation: ") - .append(tflite::EnumNameBuiltinOperator(builtin_op))); + throw std::runtime_error( + std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op))); } } -void BaseLoader::loadSubgraph(const tflite::SubGraph *subgraph) +template +void BaseLoader::loadSubgraph(const SubGraph *subgraph) { // Load tensors for (const auto *tensor : *subgraph->tensors()) @@ -724,7 +772,9 @@ void BaseLoader::loadSubgraph(const tflite::SubGraph *subgraph) // Name unused } -void BaseLoader::loadConstantTensor(const tflite::Buffer *buffer, const uint32_t &index) +template +void BaseLoader::loadConstantTensor(const Buffer *buffer, + const uint32_t &index) { const auto *data = buffer->data(); if (data != nullptr) @@ -736,9 +786,10 @@ void BaseLoader::loadConstantTensor(const tflite::Buffer *buffer, const uint32_t } } -void BaseLoader::loadModel() +template +void BaseLoader::loadModel() { - const auto *model = tflite::GetModel(_buffer.data()); + const auto *model = LoaderDomain::GetModel(_buffer.data()); // Version unused // const auto version = model->version(); const auto *op_codes = model->operator_codes(); @@ -753,7 +804,7 @@ void BaseLoader::loadModel() { _op_code_to_builtin_op.push_back(op_code->builtin_code()); - if (op_code->builtin_code() == tflite::BuiltinOperator_CUSTOM) + if (op_code->builtin_code() == BuiltinOperator::BuiltinOperator_CUSTOM) { auto id = op_code->custom_code()->str(); _opcode_index_to_custom_opcode[_op_code_to_builtin_op.size() - 1] = id; diff --git a/runtimes/neurun/frontend/tflite/tflite_loader.cc b/runtimes/neurun/frontend/tflite/tflite_loader.cc index c8b2bc5..26566d7 100644 --- a/runtimes/neurun/frontend/tflite/tflite_loader.cc +++ b/runtimes/neurun/frontend/tflite/tflite_loader.cc @@ -15,8 +15,8 @@ */ #include "tflite_loader.h" -#include "schema_generated.h" #include "base_loader.h" +#include "schema_generated.h" namespace neurun { @@ -26,7 +26,33 @@ namespace tflite_loader namespace { -class TFLiteLoader : private base_loader::BaseLoader +struct LoaderDomain +{ + using ActivationFunctionType = tflite::ActivationFunctionType; + using Buffer = tflite::Buffer; + using BuiltinOperator = tflite::BuiltinOperator; + using CustomOptionsFormat = tflite::CustomOptionsFormat; + using Model = tflite::Model; + using Operator = tflite::Operator; + using Padding = tflite::Padding; + using Pool2DOptions = tflite::Pool2DOptions; + using Tensor = tflite::Tensor; + using TensorType = tflite::TensorType; + using SubGraph = tflite::SubGraph; + + static const char *EnumNameBuiltinOperator(BuiltinOperator e) + { + return tflite::EnumNameBuiltinOperator(e); + } + static const char *EnumNameActivationFunctionType(ActivationFunctionType e) + { + return tflite::EnumNameActivationFunctionType(e); + } + static const char *EnumNameTensorType(TensorType e) { return tflite::EnumNameTensorType(e); } + static const Model *GetModel(const void *buf) { return tflite::GetModel(buf); } +}; + +class TFLiteLoader final : private base_loader::BaseLoader { public: using BaseLoader::BaseLoader; -- 2.7.4