namespace base_loader
{
-class BaseLoader
-{
+template <typename LoaderDomain, typename SpecificLoader> class BaseLoader
+{
+ using ActivationFunctionType = typename LoaderDomain::ActivationFunctionType;
+ using Buffer = typename LoaderDomain::Buffer;
+ using BuiltinOperator = typename LoaderDomain::BuiltinOperator;
+ using CustomOptionsFormat = typename LoaderDomain::CustomOptionsFormat;
+ using Model = typename LoaderDomain::Model;
+ using Operator = typename LoaderDomain::Operator;
+ using Padding = typename LoaderDomain::Padding;
+ using Pool2DOptions = typename LoaderDomain::Pool2DOptions;
+ using SubGraph = typename LoaderDomain::SubGraph;
+ using Tensor = typename LoaderDomain::Tensor;
+ using TensorType = typename LoaderDomain::TensorType;
+
public:
/**
* @brief Construct a new Loader object
protected:
~BaseLoader() = default;
-private:
void loadModel();
// Helper functions
- model::Activation convertActivation(tflite::ActivationFunctionType type);
- model::DataType tensorTypeToDataType(tflite::TensorType type);
+ model::Activation convertActivation(ActivationFunctionType type);
+ model::DataType tensorTypeToDataType(TensorType type);
// Load subgraphs
- void loadSubgraph(const tflite::SubGraph *subgraph);
+ void loadSubgraph(const SubGraph *subgraph);
// Load data from buffer to tensor on index
- void loadConstantTensor(const tflite::Buffer *buffer, const uint32_t &index);
+ void loadConstantTensor(const Buffer *buffer, const uint32_t &index);
// Create operands form tflite::Tensor
- void loadOperand(const tflite::Tensor *tensor);
- void loadOperationIO(const tflite::Operator *op, model::OperandIndexSequence &inputs,
+ void loadOperand(const Tensor *tensor);
+ void loadOperationIO(const Operator *op, model::OperandIndexSequence &inputs,
model::OperandIndexSequence &outputs);
- // Create operations from tflite::Operator
- void loadOperation(const tflite::Operator *op);
+ // Create operations from Operator
+ void loadOperation(const Operator *op);
// Load Strides and Paddings from options to param
template <typename Param, typename OptionsType>
void loadStridesAndPaddings(Param ¶m, const OptionsType *options);
// Load Pool2D param
- template <typename Param> void loadPool2D(Param ¶m, const tflite::Pool2DOptions *options);
+ template <typename Param> void loadPool2D(Param ¶m, const Pool2DOptions *options);
// Operations
- void loadConv2D(const tflite::Operator *op);
- void loadDepthwiseConv2D(const tflite::Operator *op);
- void loadTransposeConv(const tflite::Operator *op);
- void loadAvgPool2D(const tflite::Operator *op);
- void loadReshape(const tflite::Operator *op);
- void loadSoftmax(const tflite::Operator *op);
- void loadMaxPool2D(const tflite::Operator *op);
- void loadConcatenation(const tflite::Operator *op);
- void loadFC(const tflite::Operator *op);
- void loadAdd(const tflite::Operator *op);
- void loadSub(const tflite::Operator *op);
- void loadMul(const tflite::Operator *op);
- void loadDiv(const tflite::Operator *op);
- void loadRelu(const tflite::Operator *op);
- void loadRelu6(const tflite::Operator *op);
- void loadRsqrt(const tflite::Operator *op);
- void loadSqrt(const tflite::Operator *op);
- void loadSquaredDifference(const tflite::Operator *op);
- void loadTanh(const tflite::Operator *op);
- void loadTranspose(const tflite::Operator *op);
- void loadMean(const tflite::Operator *op);
- void loadPad(const tflite::Operator *op);
- void loadCustom(const tflite::Operator *op);
-
-private:
+ void loadConv2D(const Operator *op);
+ void loadDepthwiseConv2D(const Operator *op);
+ void loadTransposeConv(const Operator *op);
+ void loadAvgPool2D(const Operator *op);
+ void loadReshape(const Operator *op);
+ void loadSoftmax(const Operator *op);
+ void loadMaxPool2D(const Operator *op);
+ void loadConcatenation(const Operator *op);
+ void loadFC(const Operator *op);
+ void loadAdd(const Operator *op);
+ void loadSub(const Operator *op);
+ void loadMul(const Operator *op);
+ void loadDiv(const Operator *op);
+ void loadRelu(const Operator *op);
+ void loadRelu6(const Operator *op);
+ void loadRsqrt(const Operator *op);
+ void loadSqrt(const Operator *op);
+ void loadSquaredDifference(const Operator *op);
+ void loadTanh(const Operator *op);
+ void loadTranspose(const Operator *op);
+ void loadMean(const Operator *op);
+ void loadPad(const Operator *op);
+ void loadCustom(const Operator *op);
+
+protected:
// Buffer for loading (if needed)
std::vector<char> _buffer;
// Reference on loadable Graph
graph::Graph &_graph;
- // Mapping from tflite tensor index to Graph OperandIndex
+ // Mapping from tensor index to Graph OperandIndex
std::map<uint32_t, model::OperandIndex> _tensor_to_operand;
// Mapping from operator code to BuiltinOperator
- std::vector<tflite::BuiltinOperator> _op_code_to_builtin_op;
+ std::vector<BuiltinOperator> _op_code_to_builtin_op;
std::unordered_map<uint32_t, std::string> _opcode_index_to_custom_opcode;
};
-void BaseLoader::loadFromFile(const char *file_path)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::loadFromFile(const char *file_path)
{
std::ifstream stream(file_path, std::fstream::ate | std::fstream::binary);
auto size = stream.tellg();
loadModel();
}
-model::Activation BaseLoader::convertActivation(const tflite::ActivationFunctionType type)
+template <typename LoaderDomain, typename SpecificLoader>
+model::Activation BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::convertActivation(
+ const ActivationFunctionType type)
{
switch (type)
{
- case tflite::ActivationFunctionType::ActivationFunctionType_NONE:
+ case ActivationFunctionType::ActivationFunctionType_NONE:
return model::Activation::NONE;
- case tflite::ActivationFunctionType::ActivationFunctionType_RELU:
+ case ActivationFunctionType::ActivationFunctionType_RELU:
return model::Activation::RELU;
- case tflite::ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
+ case ActivationFunctionType::ActivationFunctionType_RELU_N1_TO_1:
return model::Activation::RELU1;
- case tflite::ActivationFunctionType::ActivationFunctionType_RELU6:
+ case ActivationFunctionType::ActivationFunctionType_RELU6:
return model::Activation::RELU6;
- case tflite::ActivationFunctionType::ActivationFunctionType_TANH:
+ case ActivationFunctionType::ActivationFunctionType_TANH:
return model::Activation::TANH;
default:
throw std::runtime_error(std::string("Unsupported activation type: ")
- .append(tflite::EnumNameActivationFunctionType(type)));
+ .append(EnumNameActivationFunctionType(type)));
}
}
-model::DataType BaseLoader::tensorTypeToDataType(const tflite::TensorType type)
+template <typename LoaderDomain, typename SpecificLoader>
+model::DataType
+BaseLoader<LoaderDomain, SpecificLoader>::BaseLoader::tensorTypeToDataType(const TensorType type)
{
switch (type)
{
- case tflite::TensorType::TensorType_FLOAT32:
+ case TensorType::TensorType_FLOAT32:
return model::DataType::FLOAT32;
- case tflite::TensorType::TensorType_INT32:
+ case TensorType::TensorType_INT32:
return model::DataType::INT32;
- case tflite::TensorType::TensorType_BOOL:
+ case TensorType::TensorType_BOOL:
return model::DataType::BOOL8;
- case tflite::TensorType::TensorType_INT8:
+ case TensorType::TensorType_INT8:
return model::DataType::QUANT8_ASYMM;
default:
throw std::runtime_error(
- std::string("Unsupported tensor type: ").append(tflite::EnumNameTensorType(type)));
+ std::string("Unsupported tensor type: ").append(EnumNameTensorType(type)));
}
}
-void BaseLoader::loadOperand(const tflite::Tensor *tensor)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadOperand(const Tensor *tensor)
{
model::Shape shape;
std::unique_ptr<model::Data> data_ptr;
throw std::runtime_error("Variable tensor not supported!");
}
-void BaseLoader::loadOperationIO(const tflite::Operator *op, model::OperandIndexSequence &inputs,
- model::OperandIndexSequence &outputs)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadOperationIO(const Operator *op,
+ model::OperandIndexSequence &inputs,
+ model::OperandIndexSequence &outputs)
{
for (const auto &idx : *op->inputs())
{
}
}
+template <typename LoaderDomain, typename SpecificLoader>
template <typename Param, typename OptionsType>
-void BaseLoader::loadStridesAndPaddings(Param ¶m, const OptionsType *options)
+void BaseLoader<LoaderDomain, SpecificLoader>::loadStridesAndPaddings(Param ¶m,
+ const OptionsType *options)
{
model::Shape shape;
model::TypeInfo type_info(model::DataType::INT32);
param.stride.vertical = options->stride_w();
param.stride.horizontal = options->stride_h();
// Paddings
- if (options->padding() == tflite::Padding::Padding_SAME)
+ if (options->padding() == Padding::Padding_SAME)
param.padding.type = model::PaddingType::SAME;
- if (options->padding() == tflite::Padding::Padding_VALID)
+ if (options->padding() == Padding::Padding_VALID)
param.padding.type = model::PaddingType::VALID;
// param paddings indexes unused
}
+template <typename LoaderDomain, typename SpecificLoader>
template <typename Param>
-void BaseLoader::loadPool2D(Param ¶m, const tflite::Pool2DOptions *options)
+void BaseLoader<LoaderDomain, SpecificLoader>::loadPool2D(Param ¶m,
+ const Pool2DOptions *options)
{
// Strides and Paddings
loadStridesAndPaddings(param, options);
param.activation = convertActivation(options->fused_activation_function());
}
-void BaseLoader::loadConv2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadConv2D(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadDepthwiseConv2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadDepthwiseConv2D(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadTransposeConv(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadTransposeConv(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadAvgPool2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadAvgPool2D(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadReshape(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadReshape(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadSoftmax(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSoftmax(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadMaxPool2D(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadMaxPool2D(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadConcatenation(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadConcatenation(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadFC(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadFC(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadAdd(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadAdd(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadSub(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSub(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadMul(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadMul(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadDiv(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadDiv(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadRelu(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadRelu(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadRelu6(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadRelu6(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadRsqrt(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadRsqrt(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadSqrt(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSqrt(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadSquaredDifference(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSquaredDifference(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadTanh(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadTanh(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadTranspose(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadTranspose(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadMean(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadMean(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadPad(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadPad(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadCustom(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadCustom(const Operator *op)
{
model::OperandIndexSequence inputs;
model::OperandIndexSequence outputs;
auto constraint = model::operation::OperandConstraint::createExact(inputs.size());
- assert(op->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS &&
+ assert(op->custom_options_format() == CustomOptionsFormat::CustomOptionsFormat_FLEXBUFFERS &&
"Unsupported custom operation options format");
size_t custom_op_data_size = op->custom_options()->size();
_graph.addOperation(std::move(new_op));
}
-void BaseLoader::loadOperation(const tflite::Operator *op)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadOperation(const Operator *op)
{
const auto builtin_op = _op_code_to_builtin_op[op->opcode_index()];
switch (builtin_op)
{
- case tflite::BuiltinOperator_CONV_2D:
+ case BuiltinOperator::BuiltinOperator_CONV_2D:
loadConv2D(op);
return;
- case tflite::BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator::BuiltinOperator_AVERAGE_POOL_2D:
loadAvgPool2D(op);
return;
- case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
+ case BuiltinOperator::BuiltinOperator_DEPTHWISE_CONV_2D:
loadDepthwiseConv2D(op);
return;
- case tflite::BuiltinOperator_TRANSPOSE_CONV:
+ case BuiltinOperator::BuiltinOperator_TRANSPOSE_CONV:
loadTransposeConv(op);
return;
- case tflite::BuiltinOperator_RESHAPE:
+ case BuiltinOperator::BuiltinOperator_RESHAPE:
loadReshape(op);
return;
- case tflite::BuiltinOperator_SOFTMAX:
+ case BuiltinOperator::BuiltinOperator_SOFTMAX:
loadSoftmax(op);
return;
- case tflite::BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator::BuiltinOperator_MAX_POOL_2D:
loadMaxPool2D(op);
return;
- case tflite::BuiltinOperator_CONCATENATION:
+ case BuiltinOperator::BuiltinOperator_CONCATENATION:
loadConcatenation(op);
return;
- case tflite::BuiltinOperator_FULLY_CONNECTED:
+ case BuiltinOperator::BuiltinOperator_FULLY_CONNECTED:
loadFC(op);
return;
- case tflite::BuiltinOperator_ADD:
+ case BuiltinOperator::BuiltinOperator_ADD:
loadAdd(op);
return;
- case tflite::BuiltinOperator_SUB:
+ case BuiltinOperator::BuiltinOperator_SUB:
loadSub(op);
return;
- case tflite::BuiltinOperator_MUL:
+ case BuiltinOperator::BuiltinOperator_MUL:
loadMul(op);
return;
- case tflite::BuiltinOperator_DIV:
+ case BuiltinOperator::BuiltinOperator_DIV:
loadDiv(op);
return;
- case tflite::BuiltinOperator_RELU:
+ case BuiltinOperator::BuiltinOperator_RELU:
loadRelu(op);
return;
- case tflite::BuiltinOperator_RELU6:
+ case BuiltinOperator::BuiltinOperator_RELU6:
loadRelu6(op);
return;
- case tflite::BuiltinOperator_RSQRT:
+ case BuiltinOperator::BuiltinOperator_RSQRT:
loadRsqrt(op);
return;
- case tflite::BuiltinOperator_SQRT:
+ case BuiltinOperator::BuiltinOperator_SQRT:
loadSqrt(op);
return;
- case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
+ case BuiltinOperator::BuiltinOperator_SQUARED_DIFFERENCE:
loadSquaredDifference(op);
return;
- case tflite::BuiltinOperator_TANH:
+ case BuiltinOperator::BuiltinOperator_TANH:
loadTanh(op);
return;
- case tflite::BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator::BuiltinOperator_TRANSPOSE:
loadTranspose(op);
return;
- case tflite::BuiltinOperator_MEAN:
+ case BuiltinOperator::BuiltinOperator_MEAN:
loadMean(op);
return;
- case tflite::BuiltinOperator_PAD:
+ case BuiltinOperator::BuiltinOperator_PAD:
loadPad(op);
return;
- case tflite::BuiltinOperator_CUSTOM:
+ case BuiltinOperator::BuiltinOperator_CUSTOM:
loadCustom(op);
return;
default:
- throw std::runtime_error(std::string("Unsupported operation: ")
- .append(tflite::EnumNameBuiltinOperator(builtin_op)));
+ throw std::runtime_error(
+ std::string("Unsupported operation: ").append(EnumNameBuiltinOperator(builtin_op)));
}
}
-void BaseLoader::loadSubgraph(const tflite::SubGraph *subgraph)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadSubgraph(const SubGraph *subgraph)
{
// Load tensors
for (const auto *tensor : *subgraph->tensors())
// Name unused
}
-void BaseLoader::loadConstantTensor(const tflite::Buffer *buffer, const uint32_t &index)
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadConstantTensor(const Buffer *buffer,
+ const uint32_t &index)
{
const auto *data = buffer->data();
if (data != nullptr)
}
}
-void BaseLoader::loadModel()
+template <typename LoaderDomain, typename SpecificLoader>
+void BaseLoader<LoaderDomain, SpecificLoader>::loadModel()
{
- const auto *model = tflite::GetModel(_buffer.data());
+ const auto *model = LoaderDomain::GetModel(_buffer.data());
// Version unused
// const auto version = model->version();
const auto *op_codes = model->operator_codes();
{
_op_code_to_builtin_op.push_back(op_code->builtin_code());
- if (op_code->builtin_code() == tflite::BuiltinOperator_CUSTOM)
+ if (op_code->builtin_code() == BuiltinOperator::BuiltinOperator_CUSTOM)
{
auto id = op_code->custom_code()->str();
_opcode_index_to_custom_opcode[_op_code_to_builtin_op.size() - 1] = id;