loadModel();
}
-void Loader::loadOperand(const Tensor * /*tensor*/) { throw std::runtime_error("NYI"); }
+model::DataType tensorTypeToDataType(const TensorType &type)
+{
+ switch (type)
+ {
+ case TensorType::TensorType_FLOAT32:
+ return model::DataType::FLOAT32;
+ case TensorType::TensorType_INT32:
+ return model::DataType::INT32;
+ case TensorType::TensorType_BOOL:
+ return model::DataType::BOOL8;
+ case TensorType::TensorType_INT8:
+ return model::DataType::QUANT8_ASYMM;
+ default:
+ throw std::runtime_error("NYI");
+ }
+}
+
+void Loader::loadOperand(const tflite::Tensor *tensor)
+{
+ model::Shape shape;
+ std::unique_ptr<model::Data> data_ptr;
+ // Shape
+ const auto *tensor_shape = tensor->shape();
+ for (const auto &dim : *tensor_shape)
+ {
+ shape.append(dim);
+ }
+ // Type
+ model::DataType data_type = tensorTypeToDataType(tensor->type());
+ // Create TypeInfo
+ model::TypeInfo type_info(data_type);
+ // Create operand
+ const auto &operand_index = _graph.addOperand(shape, type_info);
+ // Buffer index
+ if (tensor->buffer() != 0)
+ _tensor_to_operand[tensor->buffer()] = operand_index;
+ // Name unused
+ // auto name = tensor->name();
+ // Quantization
+ auto quantization = tensor->quantization();
+ auto scale = quantization->scale();
+ auto zero_point = quantization->zero_point();
+ if (scale != NULL || zero_point != NULL)
+ throw std::runtime_error("Quantization is not supported!");
-void Loader::loadOperation(const Operator * /*op*/) { throw std::runtime_error("NYI"); }
+ auto details = quantization->details_as_CustomQuantization();
+ if (details != NULL)
+ throw std::runtime_error("Custom Quantization is not supported");
+ // Variablie
+ if (tensor->is_variable())
+ throw std::runtime_error("Variable tensor not supported!");
+}
-void Loader::loadSubgraph(const SubGraph * /*subgraph*/) { throw std::runtime_error("NYI"); }
+void Loader::loadConv2D(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
-void Loader::loadConstantTensor(const Buffer * /*buffer*/, const uint32_t & /*index*/)
+void Loader::loadDepthwiseConv2D(const tflite::Operator * /*op*/)
{
throw std::runtime_error("NYI");
}
-void Loader::loadModel() { throw std::runtime_error("NYI"); }
+void Loader::loadAvgPool2D(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
+
+void Loader::loadReshape(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
+
+void Loader::loadSoftmax(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
+
+void Loader::loadMaxPool2D(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
+
+void Loader::loadConcatenation(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
+
+void Loader::loadFC(const tflite::Operator * /*op*/) { throw std::runtime_error("NYI"); }
+
+void Loader::loadOperation(const tflite::Operator *op)
+{
+ switch (_op_code_to_builtin_op[op->opcode_index()])
+ {
+ case BuiltinOperator_CONV_2D:
+ loadConv2D(op);
+ return;
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ loadAvgPool2D(op);
+ return;
+ case BuiltinOperator_DEPTHWISE_CONV_2D:
+ loadDepthwiseConv2D(op);
+ return;
+ case BuiltinOperator_RESHAPE:
+ loadReshape(op);
+ return;
+ case BuiltinOperator_SOFTMAX:
+ loadSoftmax(op);
+ return;
+ case BuiltinOperator_MAX_POOL_2D:
+ loadMaxPool2D(op);
+ return;
+ case BuiltinOperator_CONCATENATION:
+ loadConcatenation(op);
+ return;
+ case BuiltinOperator_FULLY_CONNECTED:
+ loadFC(op);
+ return;
+ default:
+ auto *names = EnumNamesBuiltinOperator();
+ int enum_value = static_cast<int>(_op_code_to_builtin_op[op->opcode_index()]);
+ throw std::runtime_error(std::string("Unsupported operation: ").append(names[enum_value]));
+ }
+}
+
+void Loader::loadSubgraph(const SubGraph *subgraph)
+{
+ // Load tensors
+ for (const auto *tensor : *subgraph->tensors())
+ {
+ loadOperand(tensor);
+ }
+ // Set inputs
+ for (const auto &input_ind : *subgraph->inputs())
+ {
+ _graph.addInput(model::OperandIndex(input_ind));
+ }
+ // Set outputs
+ for (const auto &output_ind : *subgraph->outputs())
+ {
+ _graph.addOutput(model::OperandIndex(output_ind));
+ }
+ // Create operations
+ for (const auto *op : *subgraph->operators())
+ {
+ loadOperation(op);
+ }
+ // Name unused
+}
+
+void Loader::loadConstantTensor(const Buffer *buffer, const uint32_t &index)
+{
+ const auto *data = buffer->data();
+ if (data != nullptr)
+ {
+ auto ptr = nnfw::cpp14::make_unique<model::CachedData>(data->data(), data->size());
+ const auto &operand_index = _tensor_to_operand[index];
+ auto &operand = _graph.operands().at(operand_index);
+ operand.data(std::move(ptr));
+ }
+}
+
+void Loader::loadModel()
+{
+ const auto *model = GetModel(_buffer.data());
+ // Version unused
+ // const auto version = model->version();
+ const auto *op_codes = model->operator_codes();
+ const auto *subgraphs = model->subgraphs();
+ // Description unused
+ // const auto *description = model->description();
+ const auto *buffers = model->buffers();
+ // Metabuffer unsued
+ // const auto *metadata_buffer = model->metadata_buffer();
+ // Use operator codes
+ for (const auto *op_code : *op_codes)
+ {
+ _op_code_to_builtin_op.push_back(op_code->builtin_code());
+ // Custom code unsued
+ // Version unused
+ }
+ // Load subgraphs
+ for (const auto *subgraph : *subgraphs)
+ {
+ loadSubgraph(subgraph);
+ }
+ // Load buffers with constant tensors
+ for (uint32_t ind = 0; ind < buffers->size(); ind++)
+ {
+ loadConstantTensor(buffers->Get(ind), ind);
+ }
+
+ _graph.finishBuilding();
+}
} // namespace tflite_loader