Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / tfl-inspect / src / Reader.cpp
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Reader.h"
18
19 #include <cassert>
20 #include <sstream>
21 #include <string>
22
23 namespace tflinspect
24 {
25
26 // This will provide v3/v3a format neutral BuiltinOperator
27 tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
28 {
29   assert(opcode != nullptr);
30   int8_t dp_code = opcode->deprecated_builtin_code();
31   // 127 is max of int8_t which is upper bound of v3 builtin_code
32   // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
33   if (dp_code < 127 && dp_code >= 0)
34     return tflite::BuiltinOperator(dp_code);
35   return opcode->builtin_code();
36 }
37
38 bool is_valid(const tflite::OperatorCode *opcode)
39 {
40   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
41   return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
42 }
43
44 bool is_custom(const tflite::OperatorCode *opcode)
45 {
46   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
47   return (code == tflite::BuiltinOperator_CUSTOM);
48 }
49
50 std::string opcode_name(const tflite::OperatorCode *opcode)
51 {
52   assert(opcode);
53
54   if (!is_valid(opcode))
55   {
56     std::ostringstream oss;
57     oss << "(invalid)";
58     return oss.str();
59   }
60
61   if (is_custom(opcode))
62   {
63     if (!opcode->custom_code())
64       return "(invalid custom)";
65
66     std::string custom_op = "CUSTOM(";
67     custom_op += opcode->custom_code()->c_str();
68     custom_op += ")";
69     return custom_op;
70   }
71
72   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
73   return tflite::EnumNameBuiltinOperator(code);
74 }
75
76 const char *tensor_type(const tflite::Tensor *tensor)
77 {
78   return tflite::EnumNameTensorType(tensor->type());
79 }
80
81 const char *tensor_name(const tflite::Tensor *tensor)
82 {
83   static const char *kEmptyTensorName = "(noname)";
84
85   auto name = tensor->name();
86   if (name)
87     return name->c_str();
88
89   return kEmptyTensorName;
90 }
91
92 Reader::Reader(const tflite::Model *model)
93 {
94   _subgraphs = model->subgraphs();
95   _buffers = model->buffers();
96
97   auto opcodes = model->operator_codes();
98   for (const ::tflite::OperatorCode *opcode : *opcodes)
99   {
100     _op_codes.push_back(opcode);
101   }
102 }
103
104 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
105 {
106   if (buff_data != nullptr)
107   {
108     *buff_data = nullptr;
109   }
110
111   if (buf_idx == 0)
112     return 0;
113
114   if (auto *buffer = (*_buffers)[buf_idx])
115   {
116     if (auto *array = buffer->data())
117     {
118       if (size_t size = array->size())
119       {
120         if (buff_data != nullptr)
121         {
122           *buff_data = reinterpret_cast<const uint8_t *>(array->data());
123         }
124         return size;
125       }
126     }
127   }
128
129   return 0;
130 }
131
132 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
133 {
134   uint32_t index = op->opcode_index();
135   assert(index < _op_codes.size());
136   const tflite::OperatorCode *opcode = _op_codes.at(index);
137
138   return tflinspect::builtin_code_neutral(opcode);
139 }
140
141 std::string Reader::opcode_name(const tflite::Operator *op) const
142 {
143   uint32_t index = op->opcode_index();
144   assert(index < _op_codes.size());
145   const tflite::OperatorCode *opcode = _op_codes.at(index);
146
147   if (!is_valid(opcode))
148   {
149     std::ostringstream oss;
150     oss << "(invalid: " << index << ")";
151     return oss.str();
152   }
153
154   return tflinspect::opcode_name(opcode);
155 }
156
157 bool Reader::select_subgraph(uint32_t sgindex)
158 {
159   _tensors = nullptr;
160   _operators = nullptr;
161
162   _inputs.clear();
163   _outputs.clear();
164
165   if (_subgraphs->Length() <= sgindex)
166   {
167     assert(false);
168     return false;
169   }
170
171   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
172
173   _tensors = subgraph->tensors();
174   _operators = subgraph->operators();
175
176   _inputs = as_index_vector(subgraph->inputs());
177   _outputs = as_index_vector(subgraph->outputs());
178
179   return true;
180 }
181
182 } // namespace tflinspect