Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / tflchef / tflite / src / TFliteImport.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "TFliteImport.h"
18
19 #include "Convert.h"
20
21 #include <mio_tflite2121/Helper.h>
22
23 #include <sstream>
24
25 namespace tflchef
26 {
27
28 TFliteImport::TFliteImport(const tflite::Model *model)
29 {
30   _subgraphs = model->subgraphs();
31   _buffers = model->buffers();
32
33   auto opcodes = model->operator_codes();
34   for (const ::tflite::OperatorCode *opcode : *opcodes)
35   {
36     _op_codes.push_back(opcode);
37   }
38 }
39
40 bool TFliteImport::select_sub_graph(uint32_t sgindex)
41 {
42   _tensors = nullptr;
43   _operators = nullptr;
44   _inputs.clear();
45   _outputs.clear();
46
47   if (_subgraphs->Length() <= sgindex)
48   {
49     assert(false);
50     return false;
51   }
52
53   const tflite::SubGraph *subgraph = (*_subgraphs)[sgindex];
54
55   _tensors = subgraph->tensors();
56   _operators = subgraph->operators();
57
58   _inputs = as_index_vector(subgraph->inputs());
59   _outputs = as_index_vector(subgraph->outputs());
60
61   return true;
62 }
63
64 tflite::BuiltinOperator TFliteImport::builtin_code(const tflite::Operator *op) const
65 {
66   uint32_t index = op->opcode_index();
67   assert(index < _op_codes.size());
68   const tflite::OperatorCode *opcode = _op_codes.at(index);
69
70   return mio::tflite::builtin_code_neutral(opcode);
71 }
72
73 std::string TFliteImport::opcode_name(const tflite::Operator *op) const
74 {
75   uint32_t index = op->opcode_index();
76   assert(index < _op_codes.size());
77   const tflite::OperatorCode *opcode = _op_codes.at(index);
78
79   if (!mio::tflite::is_valid(opcode))
80   {
81     std::ostringstream oss;
82     oss << "(invalid: " << index << ")";
83     return oss.str();
84   }
85
86   if (mio::tflite::is_custom(opcode))
87   {
88     if (!opcode->custom_code())
89       return "(invalid custom)";
90
91     return opcode->custom_code()->c_str();
92   }
93
94   tflite::BuiltinOperator code = mio::tflite::builtin_code_neutral(opcode);
95   return EnumNameBuiltinOperator(code);
96 }
97
98 size_t TFliteImport::buffer_info(const tflite::Tensor *tensor, const uint8_t **buff_data)
99 {
100   *buff_data = nullptr;
101
102   if (tensor->buffer() == 0)
103     return 0;
104
105   if (auto *buffer = (*_buffers)[tensor->buffer()])
106   {
107     if (auto *array = buffer->data())
108     {
109       if (size_t size = array->size())
110       {
111         *buff_data = reinterpret_cast<const uint8_t *>(array->data());
112         return size;
113       }
114     }
115   }
116
117   return 0;
118 }
119
120 } // namespace tflchef