f9782d9effb6fcb5199e995e63ad5565da95a26a
[platform/core/ml/nnfw.git] / compiler / tfldump / src / Read.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Read.h"
18
19 #include <sstream>
20 #include <string>
21
22 namespace tflread
23 {
24
25 bool is_valid(const tflite::OperatorCode *opcode)
26 {
27   tflite::BuiltinOperator code = opcode->builtin_code();
28   return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
29 }
30
31 bool is_custom(const tflite::OperatorCode *opcode)
32 {
33   tflite::BuiltinOperator code = opcode->builtin_code();
34   return (code == tflite::BuiltinOperator_CUSTOM);
35 }
36
37 std::string opcode_name(const tflite::OperatorCode *opcode)
38 {
39   assert(opcode);
40
41   if (!is_valid(opcode))
42   {
43     std::ostringstream oss;
44     oss << "(invalid)";
45     return oss.str();
46   }
47
48   if (is_custom(opcode))
49   {
50     if (!opcode->custom_code())
51       return "(invalid custom)";
52
53     std::string custom_op = "CUSTOM(";
54     custom_op += opcode->custom_code()->c_str();
55     custom_op += ")";
56     return custom_op;
57   }
58
59   tflite::BuiltinOperator code = opcode->builtin_code();
60   return tflite::EnumNameBuiltinOperator(code);
61 }
62
63 const char *tensor_type(const tflite::Tensor *tensor)
64 {
65   return tflite::EnumNameTensorType(tensor->type());
66 }
67
68 const char *tensor_name(const tflite::Tensor *tensor)
69 {
70   static const char *kEmptyTensorName = "(noname)";
71
72   auto name = tensor->name();
73   if (name)
74     return name->c_str();
75
76   return kEmptyTensorName;
77 }
78
79 Reader::Reader(const tflite::Model *model)
80 {
81   _version = model->version();
82   _subgraphs = model->subgraphs();
83   _buffers = model->buffers();
84
85   auto opcodes = model->operator_codes();
86   for (const ::tflite::OperatorCode *opcode : *opcodes)
87   {
88     _op_codes.push_back(opcode);
89   }
90 }
91
92 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
93 {
94   *buff_data = nullptr;
95
96   if (buf_idx == 0)
97     return 0;
98
99   if (auto *buffer = (*_buffers)[buf_idx])
100   {
101     if (auto *array = buffer->data())
102     {
103       if (size_t size = array->size())
104       {
105         *buff_data = reinterpret_cast<const uint8_t *>(array->data());
106         return size;
107       }
108     }
109   }
110
111   return 0;
112 }
113
114 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
115 {
116   uint32_t index = op->opcode_index();
117   assert(index < _op_codes.size());
118   const tflite::OperatorCode *opcode = _op_codes.at(index);
119
120   return opcode->builtin_code();
121 }
122
123 std::string Reader::opcode_name(const tflite::Operator *op) const
124 {
125   uint32_t index = op->opcode_index();
126   assert(index < _op_codes.size());
127   const tflite::OperatorCode *opcode = _op_codes.at(index);
128
129   if (!is_valid(opcode))
130   {
131     std::ostringstream oss;
132     oss << "(invalid: " << index << ")";
133     return oss.str();
134   }
135
136   return tflread::opcode_name(opcode);
137 }
138
139 bool Reader::select_subgraph(uint32_t sgindex)
140 {
141   _subgraph_index = sgindex;
142   _tensors = nullptr;
143   _operators = nullptr;
144
145   _inputs.clear();
146   _outputs.clear();
147
148   if (_subgraphs->Length() <= sgindex)
149   {
150     assert(false);
151     return false;
152   }
153
154   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
155
156   auto name = subgraph->name();
157   _subgraph_name = name ? name->c_str() : "(noname)";
158
159   _tensors = subgraph->tensors();
160   _operators = subgraph->operators();
161
162   _inputs = as_index_vector(subgraph->inputs());
163   _outputs = as_index_vector(subgraph->outputs());
164
165   return true;
166 }
167
168 } // namespace tflread