2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <tfldump/Dump.h>
20 #include "OpPrinter.h"
24 #include <algorithm> // min
25 #include <iomanip> // setfill
// Writes the bytes of `buffer` to `os` as hexadecimal numbers.
//
// os     : destination stream; its format flags are captured in `saveflags` first
// buffer : start of the byte data
// size   : number of bytes available in `buffer`
// amount : upper bound on the number of bytes printed when truncation is active
//
// NOTE(review): `saveflags` is presumably used to restore the stream flags before
// returning - confirm that the restore is still in place.
30 void dump_buffer(std::ostream &os, const uint8_t *buffer, size_t size, size_t amount)
32 std::ios_base::fmtflags saveflags(os.flags());
// Truncation is considered only when a limit was given (amount > 0) and the buffer
// holds more than 4 bytes.
35 bool ellipsis = amount > 0 && size > 4;
// With truncation: visit min(size, amount) bytes; without: visit all `size` bytes.
36 size_t count = ellipsis ? std::min(size, amount) : size;
38 for (size_t i = 0; i < count; i++)
// Hex with base prefix, zero fill and a minimum field width of 2.
45 os << std::showbase << std::setfill('0') << std::setw(2);
// Widen to uint32_t: a uint8_t would be streamed as a character, not as a number.
46 os << std::hex << (uint32_t)buffer[i];
// Writes the int32 vector `vs` to `os`. The operator<< overload for
// std::vector<int32_t> forwards to this function.
// NOTE(review): presumably prints the elements with a separator - confirm against the body.
58 void dump_vector(std::ostream &os, const std::vector<int32_t> &vs)
// Stream-insertion overload for std::vector<int32_t>; the actual formatting is
// delegated to tfldump::dump_vector.
70 std::ostream &operator<<(std::ostream &os, const std::vector<int32_t> &vect)
72 tfldump::dump_vector(os, vect);
// Generic dump of a flatbuffers vector: walks the first `size` entries of `fbvect`.
// The caller chooses `size`, so it may be smaller than fbvect->size().
// NOTE(review): presumably each visited element is streamed to `os` - confirm in the loop body.
77 void dump_fbvect(std::ostream &os, const flatbuffers::Vector<T> *fbvect, uint32_t size)
79 for (uint32_t q = 0; q < size; q++)
// uint8_t variant of dump_fbvect: walks the first `size` entries of `fbvect`.
88 void dump_fbvect(std::ostream &os, const flatbuffers::Vector<uint8_t> *fbvect, uint32_t size)
91 for (uint32_t q = 0; q < size; q++)
// Widen to uint32_t so the byte is printed as a number instead of as a character.
95 os << static_cast<uint32_t>(fbvect->Get(q));
// Stream-insertion overload for a pointer to a flatbuffers vector.
// Prints the element count in parentheses followed by at most the first 4 elements.
100 std::ostream &operator<<(std::ostream &os, const flatbuffers::Vector<T> *fbvect)
// A null vector pointer is handled separately, before any size() call.
102 if (fbvect == nullptr)
// Vectors with more than 4 elements are abbreviated.
105 bool ellipsis = (fbvect->size() > 4);
// Number of elements actually dumped: 4 when abbreviated, otherwise all of them.
106 auto limit_size = ellipsis ? 4 : fbvect->size();
// The full element count is always printed, even when the content is abbreviated.
110 os << "(" << fbvect->size() << ") ";
113 dump_fbvect(os, fbvect, limit_size);
// Dumps one sub-graph to `os` in three sections: operands (tensors), operators,
// and the network inputs/outputs.
//
// os     : destination stream
// reader : source of the tensors, operators, inputs and outputs; dump_model calls
//          reader.select_subgraph() before invoking this function
//
// Every tensor/operator index is printed together with reader.subgraph_index().
// May throw std::runtime_error for an unexpected SparseIndexVector type in the
// sparsity metadata.
123 void dump_sub_graph(std::ostream &os, tflread::Reader &reader)
125 auto tensors = reader.tensors();
126 auto operators = reader.operators();
128 // dump operands(tensors)
// Legend line describing the per-tensor output format.
129 os << "Operands: T(subgraph index : tensor index) TYPE (shape) (shape_signature) "
130 << "B(buffer index) OperandName" << std::endl;
131 for (uint32_t i = 0; i < tensors->Length(); ++i)
133 // TODO refactor to some better structure
134 auto tensor = tensors->Get(i);
// Placeholder shape {-1}; replaced below by the tensor's shape converted to an index vector.
135 std::vector<int32_t> dims = {-1};
138 dims = tflread::as_index_vector(tensor->shape());
140 os << "T(" << reader.subgraph_index() << ":" << i << ") " << tflread::tensor_type(tensor)
142 os << "(" << dims << ") ";
// shape_signature is optional and printed only when the tensor carries one.
143 if (tensor->shape_signature())
145 std::vector<int32_t> dims_sig = tflread::as_index_vector(tensor->shape_signature());
146 os << "(" << dims_sig << ") ";
148 os << "B(" << tensor->buffer() << ") ";
149 os << tflread::tensor_name(tensor) << std::endl;
// Quantization parameters, if the tensor has any.
151 if (auto q_params = tensor->quantization())
// Report only when a min/max pair or a scale/zero_point pair is present.
153 if ((q_params->min() && q_params->max()) || (q_params->scale() && q_params->zero_point()))
155 std::string strquantiz = " Quantization: ";
// Run of blanks as wide as the label above, used to align continuation lines.
156 std::string strqindent(strquantiz.size(), ' ');
161 os << "min(" << q_params->min() << ") ";
// A vector with more than one element is followed by a line break plus indent,
// so the next field starts on its own aligned line.
162 if (q_params->min()->size() > 1)
163 os << std::endl << strqindent;
167 os << "max(" << q_params->max() << ") ";
168 if (q_params->max()->size() > 1)
169 os << std::endl << strqindent;
171 if (q_params->scale())
173 os << "scale(" << q_params->scale() << ") ";
174 if (q_params->scale()->size() > 1)
175 os << std::endl << strqindent;
177 if (q_params->zero_point())
179 os << "zeropt(" << q_params->zero_point() << ") ";
180 if (q_params->zero_point()->size() > 1)
181 os << std::endl << strqindent;
183 os << "quantized_dimension(" << q_params->quantized_dimension() << ")";
// Sparsity parameters, if the tensor has any.
189 if (const auto &s_params = tensor->sparsity())
191 std::string strsparsity = " Sparsity: ";
// Run of blanks as wide as the label above, used to align continuation lines.
192 std::string strsindent(strsparsity.size(), ' ');
195 if (s_params->traversal_order())
197 os << "traversal_order(" << s_params->traversal_order() << ") ";
198 os << std::endl << strsindent;
200 if (s_params->block_map())
202 os << "block_map(" << s_params->block_map() << ") ";
203 os << std::endl << strsindent;
// One group of fields is printed per dim_metadata entry.
205 if (const auto &dim_metadata = s_params->dim_metadata())
208 for (const auto &dm : *dim_metadata)
// Label carries the running entry index; `idx` is advanced as a side effect here.
210 std::string strdm = "dim_metadata[" + std::to_string(idx++) + "]: ";
// Continuation indent for this entry: sparsity indent plus the width of the label.
211 std::string strdm_indent = strsindent + std::string(strdm.size(), ' ');
214 os << "format(" << tflite::EnumNameDimensionType(dm->format()) << ") ";
215 os << std::endl << strdm_indent;
217 os << "dense_size(" << dm->dense_size() << ") ";
218 os << std::endl << strdm_indent;
220 os << "array_segments_type("
221 << tflite::EnumNameSparseIndexVector(dm->array_segments_type()) << ") ";
222 os << std::endl << strdm_indent;
224 os << "array_segments(";
// array_segments is a union: pick the typed accessor that matches its type tag.
225 switch (dm->array_segments_type())
227 case tflite::SparseIndexVector_NONE:
230 case tflite::SparseIndexVector_Int32Vector:
231 os << dm->array_segments_as_Int32Vector()->values();
233 case tflite::SparseIndexVector_Uint16Vector:
234 os << dm->array_segments_as_Uint16Vector()->values();
236 case tflite::SparseIndexVector_Uint8Vector:
237 os << dm->array_segments_as_Uint8Vector()->values();
// NOTE(review): presumably the default branch - a type tag not handled above is rejected.
240 throw std::runtime_error("Invalid SparseIndexVector type of array_segments");
242 os << ")" << std::endl << strdm_indent;
244 os << "array_indices_type(" << tflite::EnumNameSparseIndexVector(dm->array_indices_type())
246 os << std::endl << strdm_indent;
248 os << "array_indices(";
// Same dispatch as for array_segments, this time on the array_indices type tag.
249 switch (dm->array_indices_type())
251 case tflite::SparseIndexVector_NONE:
254 case tflite::SparseIndexVector_Int32Vector:
255 os << dm->array_indices_as_Int32Vector()->values();
257 case tflite::SparseIndexVector_Uint16Vector:
258 os << dm->array_indices_as_Uint16Vector()->values();
260 case tflite::SparseIndexVector_Uint8Vector:
261 os << dm->array_indices_as_Uint8Vector()->values();
// NOTE(review): presumably the default branch - a type tag not handled above is rejected.
264 throw std::runtime_error("Invalid SparseIndexVector type of array_indices");
// After the last field of an entry the indent drops back to the sparsity level.
266 os << ")" << std::endl << strsindent;
// Operators section: legend lines first, then one record per operator.
274 os << "Operators: O(subgraph index : operator index) OpCodeName " << std::endl;
275 os << " Option(values) ... <-- depending on OpCode" << std::endl;
276 os << " I T(tensor index) OperandName <-- as input" << std::endl;
277 os << " O T(tensor index) OperandName <-- as output" << std::endl;
278 for (uint32_t i = 0; i < operators->Length(); ++i)
280 const auto op = operators->Get(i);
281 tflite::BuiltinOperator builtincode = reader.builtin_code(op);
283 const std::vector<int32_t> &inputs = tflread::as_index_vector(op->inputs());
284 const std::vector<int32_t> &outputs = tflread::as_index_vector(op->outputs());
285 auto op_name = reader.opcode_name(op);
287 os << "O(" << reader.subgraph_index() << ":" << i << ") " << op_name << " ";
// Operator options are printed only when a printer is registered for this builtin code.
290 if (auto op_prn = OpPrinterRegistry::get().lookup(builtincode))
292 op_prn->options(op, os);
// Input operands of the operator: tensor index, then tensor name.
295 for (auto input : inputs)
297 os << " I T(" << reader.subgraph_index() << ":" << input << ") ";
300 auto tensor = tensors->Get(input);
301 os << tflread::tensor_name(tensor);
// Output operands of the operator: tensor index, then tensor name.
305 for (auto output : outputs)
307 os << " O T(" << reader.subgraph_index() << ":" << output << ") ";
310 auto tensor = tensors->Get(output);
311 os << tflread::tensor_name(tensor);
318 // dump network inputs/outputs
319 os << "Inputs/Outputs: I(input)/O(output) T(tensor index) OperandName" << std::endl;
// Graph-level inputs as listed by reader.inputs().
321 for (const auto input : reader.inputs())
323 auto tensor = tensors->Get(input);
324 std::string name = tflread::tensor_name(tensor);
325 os << "I T(" << reader.subgraph_index() << ":" << input << ") " << name << std::endl;
// Graph-level outputs as listed by reader.outputs().
328 for (const auto output : reader.outputs())
330 auto tensor = tensors->Get(output);
331 std::string name = tflread::tensor_name(tensor);
332 os << "O T(" << reader.subgraph_index() << ":" << output << ") " << name << std::endl;
// Dumps a whole model to `os`: model version and sub-graph count, the operator
// code table, the buffer table, and finally every sub-graph via dump_sub_graph().
//
// os    : destination stream
// model : model to dump; it is wrapped in a tflread::Reader for all accesses
338 void dump_model(std::ostream &os, const tflite::Model *model)
340 tflread::Reader reader(model);
342 uint32_t num_subgraph = reader.num_subgraph();
344 // dump model version
345 os << "===================================================================" << std::endl;
346 os << "Model version: " << reader.version() << std::endl;
347 os << " # sub graphs: " << num_subgraph << std::endl;
350 auto opcodes = reader.opcodes();
351 auto buffers = reader.buffers();
353 // dump operator_codes
354 os << "Operator Codes: [order] OpCodeName (OpCode Enum)" << std::endl;
// Position printed in brackets in front of each operator code.
355 int32_t opcode_index = 0;
356 for (auto opcode : opcodes)
358 tflite::BuiltinOperator op_code = opcode->builtin_code();
359 auto op_name = tflread::opcode_name(opcode);
360 auto op_version = opcode->version();
362 os << "[" << opcode_index << "] " << op_name << " (code: " << op_code
363 << ", version: " << op_version << ")" << std::endl;
// Buffer table: index and length for every buffer, plus a content preview when data exists.
370 os << "Buffers: B(index) (length) values, if any" << std::endl;
371 for (uint32_t i = 0; i < buffers->Length(); ++i)
// Filled in by buffer_info(); compared against nullptr below.
373 const uint8_t *buff_data;
374 size_t size = reader.buffer_info(i, &buff_data);
376 os << "B(" << i << ") (" << size << ") ";
// Content is dumped only when buffer_info() produced a data pointer.
377 if (buff_data != nullptr)
// Preview is capped at 16 bytes through dump_buffer's `amount` argument.
379 dump_buffer(os, buff_data, size, 16);
// Each sub-graph is selected on the reader first, then dumped.
385 for (uint32_t sg = 0; sg < num_subgraph; ++sg)
387 reader.select_subgraph(sg);
389 os << "-------------------------------------------------------------------" << std::endl;
390 os << "Sub-Graph: #" << sg << " " << reader.subgraph_name() << std::endl;
393 dump_sub_graph(os, reader);
396 os << "===================================================================" << std::endl;
399 } // namespace tfldump
// Stream-insertion overload for a model pointer; the dump itself is performed by
// tfldump::dump_model.
401 std::ostream &operator<<(std::ostream &os, const tflite::Model *model)
403 tfldump::dump_model(os, model);