Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / tfldump / src / Read.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Read.h"
18
19 #include <sstream>
20 #include <string>
21
22 namespace tflread
23 {
24
25 bool is_valid(const tflite::OperatorCode *opcode)
26 {
27   tflite::BuiltinOperator code = opcode->builtin_code();
28   return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
29 }
30
31 bool is_custom(const tflite::OperatorCode *opcode)
32 {
33   tflite::BuiltinOperator code = opcode->builtin_code();
34   return (code == tflite::BuiltinOperator_CUSTOM);
35 }
36
37 std::string opcode_name(const tflite::OperatorCode *opcode)
38 {
39   assert(opcode);
40
41   if (!is_valid(opcode))
42   {
43     std::ostringstream oss;
44     oss << "(invalid)";
45     return oss.str();
46   }
47
48   if (is_custom(opcode))
49   {
50     if (!opcode->custom_code())
51       return "(invalid custom)";
52
53     std::string custom_op = "CUSTOM(";
54     custom_op += opcode->custom_code()->c_str();
55     custom_op += ")";
56     return custom_op;
57   }
58
59   tflite::BuiltinOperator code = opcode->builtin_code();
60   return tflite::EnumNameBuiltinOperator(code);
61 }
62
63 const char *tensor_type(const tflite::Tensor *tensor)
64 {
65   return tflite::EnumNameTensorType(tensor->type());
66 }
67
68 const char *tensor_name(const tflite::Tensor *tensor)
69 {
70   static const char *kEmptyTensorName = "(noname)";
71
72   auto name = tensor->name();
73   if (name)
74     return name->c_str();
75
76   return kEmptyTensorName;
77 }
78
79 Reader::Reader(const tflite::Model *model)
80 {
81   _version = model->version();
82   _subgraphs = model->subgraphs();
83   _buffers = model->buffers();
84   _metadata = model->metadata();
85
86   auto opcodes = model->operator_codes();
87   for (const ::tflite::OperatorCode *opcode : *opcodes)
88   {
89     _op_codes.push_back(opcode);
90   }
91 }
92
93 size_t Reader::buffer_info(uint32_t buf_idx, const uint8_t **buff_data)
94 {
95   *buff_data = nullptr;
96
97   if (buf_idx == 0)
98     return 0;
99
100   if (auto *buffer = (*_buffers)[buf_idx])
101   {
102     if (auto *array = buffer->data())
103     {
104       if (size_t size = array->size())
105       {
106         *buff_data = reinterpret_cast<const uint8_t *>(array->data());
107         return size;
108       }
109     }
110   }
111
112   return 0;
113 }
114
115 tflite::BuiltinOperator Reader::builtin_code(const tflite::Operator *op) const
116 {
117   uint32_t index = op->opcode_index();
118   assert(index < _op_codes.size());
119   const tflite::OperatorCode *opcode = _op_codes.at(index);
120
121   return opcode->builtin_code();
122 }
123
124 std::string Reader::opcode_name(const tflite::Operator *op) const
125 {
126   uint32_t index = op->opcode_index();
127   assert(index < _op_codes.size());
128   const tflite::OperatorCode *opcode = _op_codes.at(index);
129
130   if (!is_valid(opcode))
131   {
132     std::ostringstream oss;
133     oss << "(invalid: " << index << ")";
134     return oss.str();
135   }
136
137   return tflread::opcode_name(opcode);
138 }
139
140 bool Reader::select_subgraph(uint32_t sgindex)
141 {
142   _subgraph_index = sgindex;
143   _tensors = nullptr;
144   _operators = nullptr;
145
146   _inputs.clear();
147   _outputs.clear();
148
149   if (_subgraphs->Length() <= sgindex)
150   {
151     assert(false);
152     return false;
153   }
154
155   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
156
157   auto name = subgraph->name();
158   _subgraph_name = name ? name->c_str() : "(noname)";
159
160   _tensors = subgraph->tensors();
161   _operators = subgraph->operators();
162
163   _inputs = as_index_vector(subgraph->inputs());
164   _outputs = as_index_vector(subgraph->outputs());
165
166   return true;
167 }
168
169 } // namespace tflread