Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / tfldump / src / Read.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Read.h"
18
19 #include <sstream>
20 #include <string>
21
22 namespace tflread
23 {
24
25 // This will provide v3/v3a format neutral BuiltinOperator
26 tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
27 {
28   assert(opcode != nullptr);
29   int8_t dp_code = opcode->deprecated_builtin_code();
30   if (dp_code < 127 && dp_code >= 0)
31     return tflite::BuiltinOperator(dp_code);
32   return opcode->builtin_code();
33 }
34
35 bool is_valid(const tflite::OperatorCode *opcode)
36 {
37   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
38   return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
39 }
40
41 bool is_custom(const tflite::OperatorCode *opcode)
42 {
43   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
44   return (code == tflite::BuiltinOperator_CUSTOM);
45 }
46
47 std::string opcode_name(const tflite::OperatorCode *opcode)
48 {
49   assert(opcode);
50
51   if (!is_valid(opcode))
52   {
53     std::ostringstream oss;
54     oss << "(invalid)";
55     return oss.str();
56   }
57
58   if (is_custom(opcode))
59   {
60     if (!opcode->custom_code())
61       return "(invalid custom)";
62
63     std::string custom_op = "CUSTOM(";
64     custom_op += opcode->custom_code()->c_str();
65     custom_op += ")";
66     return custom_op;
67   }
68
69   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
70   return tflite::EnumNameBuiltinOperator(code);
71 }
72
73 const char *tensor_type(const tflite::Tensor *tensor)
74 {
75   return tflite::EnumNameTensorType(tensor->type());
76 }
77
78 const char *tensor_name(const tflite::Tensor *tensor)
79 {
80   static const char *kEmptyTensorName = "(noname)";
81
82   auto name = tensor->name();
83   if (name)
84     return name->c_str();
85
86   return kEmptyTensorName;
87 }
88
89 Reader::Reader(const tflite::Model *model)
90 {
91   _version = model->version();
92   _subgraphs = model->subgraphs();
93   _buffers = model->buffers();
94   _metadata = model->metadata();
95   _signaturedefs = model->signature_defs();
96
97   auto opcodes = model->operator_codes();
98   for (const ::tflite::OperatorCode *opcode : *opcodes)
99   {
100     _op_codes.push_back(opcode);
101   }
102 }
103
104 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
105 {
106   *buff_data = nullptr;
107
108   if (buf_idx == 0)
109     return 0;
110
111   if (auto *buffer = (*_buffers)[buf_idx])
112   {
113     if (auto *array = buffer->data())
114     {
115       if (size_t size = array->size())
116       {
117         *buff_data = reinterpret_cast<const uint8_t *>(array->data());
118         return size;
119       }
120     }
121   }
122
123   return 0;
124 }
125
126 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
127 {
128   uint32_t index = op->opcode_index();
129   assert(index < _op_codes.size());
130   const tflite::OperatorCode *opcode = _op_codes.at(index);
131
132   return tflread::builtin_code_neutral(opcode);
133 }
134
135 std::string Reader::opcode_name(const tflite::Operator *op) const
136 {
137   uint32_t index = op->opcode_index();
138   assert(index < _op_codes.size());
139   const tflite::OperatorCode *opcode = _op_codes.at(index);
140
141   if (!is_valid(opcode))
142   {
143     std::ostringstream oss;
144     oss << "(invalid: " << index << ")";
145     return oss.str();
146   }
147
148   return tflread::opcode_name(opcode);
149 }
150
151 bool Reader::select_subgraph(uint32_t sgindex)
152 {
153   _subgraph_index = sgindex;
154   _tensors = nullptr;
155   _operators = nullptr;
156
157   _inputs.clear();
158   _outputs.clear();
159
160   if (_subgraphs->Length() <= sgindex)
161   {
162     assert(false);
163     return false;
164   }
165
166   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
167
168   auto name = subgraph->name();
169   _subgraph_name = name ? name->c_str() : "(noname)";
170
171   _tensors = subgraph->tensors();
172   _operators = subgraph->operators();
173
174   _inputs = as_index_vector(subgraph->inputs());
175   _outputs = as_index_vector(subgraph->outputs());
176
177   return true;
178 }
179
180 } // namespace tflread