Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / tflchef / tflite / src / TFliteImport.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "TFliteImport.h"
18
19 #include "Convert.h"
20
21 #include <sstream>
22
23 namespace tflchef
24 {
25
26 const char *kEmptyTensorName = "(noname)";
27
28 const char *tensor_type(const tflite::Tensor *tensor)
29 {
30   return tflite::EnumNameTensorType(tensor->type());
31 }
32
33 const char *tensor_name(const tflite::Tensor *tensor)
34 {
35   auto name = tensor->name();
36   if (name)
37     return name->c_str();
38   return kEmptyTensorName;
39 }
40
41 // This will provide v3/v3a format neutral BuiltinOperator
42 tflite::BuiltinOperator builtin_code_neutral(const tflite::OperatorCode *opcode)
43 {
44   assert(opcode != nullptr);
45   int8_t dp_code = opcode->deprecated_builtin_code();
46   // 127 is max of int8_t which is upper bound of v3 builtin_code
47   // NOTE TensorFlow uses 'BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES' for 127
48   if (dp_code < 127 && dp_code >= 0)
49     return tflite::BuiltinOperator(dp_code);
50   return opcode->builtin_code();
51 }
52
53 bool is_valid(const tflite::OperatorCode *opcode)
54 {
55   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
56   return (tflite::BuiltinOperator_MIN <= code && code <= tflite::BuiltinOperator_MAX);
57 }
58
59 bool is_custom(const tflite::OperatorCode *opcode)
60 {
61   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
62   return (code == tflite::BuiltinOperator_CUSTOM);
63 }
64
65 TFliteImport::TFliteImport(const tflite::Model *model)
66 {
67   _subgraphs = model->subgraphs();
68   _buffers = model->buffers();
69
70   auto opcodes = model->operator_codes();
71   for (const ::tflite::OperatorCode *opcode : *opcodes)
72   {
73     _op_codes.push_back(opcode);
74   }
75 }
76
77 bool TFliteImport::select_sub_graph(uint32_t sgindex)
78 {
79   _tensors = nullptr;
80   _operators = nullptr;
81   _inputs.clear();
82   _outputs.clear();
83
84   if (_subgraphs->Length() <= sgindex)
85   {
86     assert(false);
87     return false;
88   }
89
90   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
91
92   _tensors = subgraph->tensors();
93   _operators = subgraph->operators();
94
95   _inputs = as_index_vector(subgraph->inputs());
96   _outputs = as_index_vector(subgraph->outputs());
97
98   return true;
99 }
100
101 tflite::BuiltinOperator TFliteImport::builtin_code(const tflite::Operator *op) const
102 {
103   uint32_t index = op->opcode_index();
104   assert(index < _op_codes.size());
105   const tflite::OperatorCode *opcode = _op_codes.at(index);
106
107   return builtin_code_neutral(opcode);
108 }
109
110 std::string TFliteImport::opcode_name(const tflite::Operator *op) const
111 {
112   uint32_t index = op->opcode_index();
113   assert(index < _op_codes.size());
114   const tflite::OperatorCode *opcode = _op_codes.at(index);
115
116   if (!is_valid(opcode))
117   {
118     std::ostringstream oss;
119     oss << "(invalid: " << index << ")";
120     return oss.str();
121   }
122
123   if (is_custom(opcode))
124   {
125     if (!opcode->custom_code())
126       return "(invalid custom)";
127
128     return opcode->custom_code()->c_str();
129   }
130
131   tflite::BuiltinOperator code = builtin_code_neutral(opcode);
132   return EnumNameBuiltinOperator(code);
133 }
134
135 size_t TFliteImport::buffer_info(const tflite::Tensor *tensor, const uint8_t **buff_data)
136 {
137   *buff_data = nullptr;
138
139   if (tensor->buffer() == 0)
140     return 0;
141
142   if (auto *buffer = (*_buffers)[tensor->buffer()])
143   {
144     if (auto *array = buffer->data())
145     {
146       if (size_t size = array->size())
147       {
148         *buff_data = reinterpret_cast<const uint8_t *>(array->data());
149         return size;
150       }
151     }
152   }
153
154   return 0;
155 }
156
157 } // namespace tflchef