454e3a8a16a01bc730fabcee5a2674aa4007219d
[platform/core/ml/nnfw.git] / compiler / tfldump / src / Read.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Read.h"
18
19 #include <mio_tflite280/Helper.h>
20
21 #include <sstream>
22 #include <string>
23
24 namespace tflread
25 {
26
27 Reader::Reader(const tflite::Model *model)
28 {
29   _version = model->version();
30   _subgraphs = model->subgraphs();
31   _buffers = model->buffers();
32   _metadata = model->metadata();
33   _signaturedefs = model->signature_defs();
34
35   auto opcodes = model->operator_codes();
36   for (const ::tflite::OperatorCode *opcode : *opcodes)
37   {
38     _op_codes.push_back(opcode);
39   }
40 }
41
42 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
43 {
44   *buff_data = nullptr;
45
46   if (buf_idx == 0)
47     return 0;
48
49   if (auto *buffer = (*_buffers)[buf_idx])
50   {
51     if (auto *array = buffer->data())
52     {
53       if (size_t size = array->size())
54       {
55         *buff_data = reinterpret_cast<const uint8_t *>(array->data());
56         return size;
57       }
58     }
59   }
60
61   return 0;
62 }
63
64 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
65 {
66   uint32_t index = op->opcode_index();
67   assert(index < _op_codes.size());
68   const tflite::OperatorCode *opcode = _op_codes.at(index);
69
70   return mio::tflite::builtin_code_neutral(opcode);
71 }
72
73 std::string Reader::opcode_name(const tflite::Operator *op) const
74 {
75   uint32_t index = op->opcode_index();
76   assert(index < _op_codes.size());
77   const tflite::OperatorCode *opcode = _op_codes.at(index);
78
79   if (!mio::tflite::is_valid(opcode))
80   {
81     std::ostringstream oss;
82     oss << "(invalid: " << index << ")";
83     return oss.str();
84   }
85
86   return mio::tflite::opcode_name(opcode);
87 }
88
89 bool Reader::select_subgraph(uint32_t sgindex)
90 {
91   _subgraph_index = sgindex;
92   _tensors = nullptr;
93   _operators = nullptr;
94
95   _inputs.clear();
96   _outputs.clear();
97
98   if (_subgraphs->Length() <= sgindex)
99   {
100     assert(false);
101     return false;
102   }
103
104   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
105
106   auto name = subgraph->name();
107   _subgraph_name = name ? name->c_str() : "(noname)";
108
109   _tensors = subgraph->tensors();
110   _operators = subgraph->operators();
111
112   _inputs = as_index_vector(subgraph->inputs());
113   _outputs = as_index_vector(subgraph->outputs());
114
115   return true;
116 }
117
118 } // namespace tflread