/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "TFliteImport.h"
// Placeholder name reported by tensor_name() for a tensor that carries no name.
const char *kEmptyTensorName{"(noname)"};
28 const char *tensor_type(const tflite::Tensor *tensor)
30 return tflite::EnumNameTensorType(tensor->type());
33 const char *tensor_name(const tflite::Tensor *tensor)
35 auto name = tensor->name();
38 return kEmptyTensorName;
41 // This will provide v3/v3a format neutral BuiltinOperator
42 tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
44 assert(opcode != nullptr);
45 int8_t dp_code = opcode->deprecated_builtin_code();
46 // 127 is max of int8_t which is upper bound of v3 builtin_code
47 // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
48 if (dp_code < 127 && dp_code >= 0)
49 return tflite::BuiltinOperator(dp_code);
50 return opcode->builtin_code();
53 bool is_valid(const tflite::OperatorCode *opcode)
55 tflite::BuiltinOperator code = builtin_code_neutral(opcode);
56 return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
59 bool is_custom(const tflite::OperatorCode *opcode)
61 tflite::BuiltinOperator code = builtin_code_neutral(opcode);
62 return (code == tflite::BuiltinOperator_CUSTOM);
65 TFliteImport::TFliteImport(const tflite::Model *model)
67 _subgraphs = model->subgraphs();
68 _buffers = model->buffers();
70 auto opcodes = model->operator_codes();
71 for (const ::tflite::OperatorCode *opcode : *opcodes)
73 _op_codes.push_back(opcode);
77 bool TFliteImport::select_sub_graph(uint32_t sgindex)
84 if (_subgraphs->Length() <= sgindex)
90 const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
92 _tensors = subgraph->tensors();
93 _operators = subgraph->operators();
95 _inputs = as_index_vector(subgraph->inputs());
96 _outputs = as_index_vector(subgraph->outputs());
101 tflite::BuiltinOperator TFliteImport::builtin_code(const tflite::Operator *op) const
103 uint32_t index = op->opcode_index();
104 assert(index < _op_codes.size());
105 const tflite::OperatorCode *opcode = _op_codes.at(index);
107 return builtin_code_neutral(opcode);
110 std::string TFliteImport::opcode_name(const tflite::Operator *op) const
112 uint32_t index = op->opcode_index();
113 assert(index < _op_codes.size());
114 const tflite::OperatorCode *opcode = _op_codes.at(index);
116 if (!is_valid(opcode))
118 std::ostringstream oss;
119 oss << "(invalid: " << index << ")";
123 if (is_custom(opcode))
125 if (!opcode->custom_code())
126 return "(invalid custom)";
128 return opcode->custom_code()->c_str();
131 tflite::BuiltinOperator code = builtin_code_neutral(opcode);
132 return EnumNameBuiltinOperator(code);
135 size_t TFliteImport::buffer_info(const tflite::Tensor *tensor, const uint8_t **buff_data)
137 *buff_data = nullptr;
139 if (tensor->buffer() == 0)
142 if (auto *buffer = (*_buffers)[tensor->buffer()])
144 if (auto *array = buffer->data())
146 if (size_t size = array->size())
148 *buff_data = reinterpret_cast<const uint8_t *>(array->data());
157 } // namespace tflchef