/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "TFliteImport.h"

#include <mio_tflite2121/Helper.h>

#include <cassert>
#include <string>
28 TFliteImport::TFliteImport(const tflite::Model *model)
30 _subgraphs = model->subgraphs();
31 _buffers = model->buffers();
33 auto opcodes = model->operator_codes();
34 for (const ::tflite::OperatorCode *opcode : *opcodes)
36 _op_codes.push_back(opcode);
40 bool TFliteImport::select_sub_graph(uint32_t sgindex)
47 if (_subgraphs->Length() <= sgindex)
53 const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
55 _tensors = subgraph->tensors();
56 _operators = subgraph->operators();
58 _inputs = as_index_vector(subgraph->inputs());
59 _outputs = as_index_vector(subgraph->outputs());
64 tflite::BuiltinOperator TFliteImport::builtin_code(const tflite::Operator *op) const
66 uint32_t index = op->opcode_index();
67 assert(index < _op_codes.size());
68 const tflite::OperatorCode *opcode = _op_codes.at(index);
70 return mio::tflite::builtin_code_neutral(opcode);
73 std::string TFliteImport::opcode_name(const tflite::Operator *op) const
75 uint32_t index = op->opcode_index();
76 assert(index < _op_codes.size());
77 const tflite::OperatorCode *opcode = _op_codes.at(index);
79 if (!mio::tflite::is_valid(opcode))
81 std::ostringstream oss;
82 oss << "(invalid: " << index << ")";
86 if (mio::tflite::is_custom(opcode))
88 if (!opcode->custom_code())
89 return "(invalid custom)";
91 return opcode->custom_code()->c_str();
94 tflite::BuiltinOperator code = mio::tflite::builtin_code_neutral(opcode);
95 return EnumNameBuiltinOperator(code);
98 size_t TFliteImport::buffer_info(const tflite::Tensor *tensor, const uint8_t **buff_data)
100 *buff_data = nullptr;
102 if (tensor->buffer() == 0)
105 if (auto *buffer = (*_buffers)[tensor->buffer()])
107 if (auto *array = buffer->data())
109 if (size_t size = array->size())
111 *buff_data = reinterpret_cast<const uint8_t *>(array->data());
120 } // namespace tflchef