2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
26 // This will provide v3/v3a format neutral BuiltinOperator
27 tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
29 assert(opcode != nullptr);
30 int8_t dp_code = opcode->deprecated_builtin_code();
31 // 127 is max of int8_t which is upper bound of v3 builtin_code
32 // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
33 if (dp_code < 127 && dp_code >= 0)
34 return tflite::BuiltinOperator(dp_code);
35 return opcode->builtin_code();
38 bool is_valid(const tflite::OperatorCode *opcode)
40 tflite::BuiltinOperator code = builtin_code_neutral(opcode);
41 return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
44 bool is_custom(const tflite::OperatorCode *opcode)
46 tflite::BuiltinOperator code = builtin_code_neutral(opcode);
47 return (code == tflite::BuiltinOperator_CUSTOM);
50 std::string opcode_name(const tflite::OperatorCode *opcode)
54 if (!is_valid(opcode))
56 std::ostringstream oss;
61 if (is_custom(opcode))
63 if (!opcode->custom_code())
64 return "(invalid custom)";
66 std::string custom_op = "CUSTOM(";
67 custom_op += opcode->custom_code()->c_str();
72 tflite::BuiltinOperator code = builtin_code_neutral(opcode);
73 return tflite::EnumNameBuiltinOperator(code);
76 const char *tensor_type(const tflite::Tensor *tensor)
78 return tflite::EnumNameTensorType(tensor->type());
81 const char *tensor_name(const tflite::Tensor *tensor)
83 static const char *kEmptyTensorName = "(noname)";
85 auto name = tensor->name();
89 return kEmptyTensorName;
92 Reader::Reader(const tflite::Model *model)
94 _subgraphs = model->subgraphs();
95 _buffers = model->buffers();
97 auto opcodes = model->operator_codes();
98 for (const ::tflite::OperatorCode *opcode : *opcodes)
100 _op_codes.push_back(opcode);
104 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
106 if (buff_data != nullptr)
108 *buff_data = nullptr;
114 if (auto *buffer = (*_buffers)[buf_idx])
116 if (auto *array = buffer->data())
118 if (size_t size = array->size())
120 if (buff_data != nullptr)
122 *buff_data = reinterpret_cast<const uint8_t *>(array->data());
132 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
134 uint32_t index = op->opcode_index();
135 assert(index < _op_codes.size());
136 const tflite::OperatorCode *opcode = _op_codes.at(index);
138 return tflinspect::builtin_code_neutral(opcode);
141 std::string Reader::opcode_name(const tflite::Operator *op) const
143 uint32_t index = op->opcode_index();
144 assert(index < _op_codes.size());
145 const tflite::OperatorCode *opcode = _op_codes.at(index);
147 if (!is_valid(opcode))
149 std::ostringstream oss;
150 oss << "(invalid: " << index << ")";
154 return tflinspect::opcode_name(opcode);
157 bool Reader::select_subgraph(uint32_t sgindex)
160 _operators = nullptr;
165 if (_subgraphs->Length() <= sgindex)
171 const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
173 _tensors = subgraph->tensors();
174 _operators = subgraph->operators();
176 _inputs = as_index_vector(subgraph->inputs());
177 _outputs = as_index_vector(subgraph->outputs());
182 } // namespace tflinspect