Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / tflite2circle / src / DataLookup.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "DataLookup.h"
18 #include "BuildBuiltinOptions.h"
19
20 namespace tflite2circle
21 {
22
23 circle::BuiltinOperator get_circle_builtin_code(tflite::BuiltinOperator tfl_bop)
24 {
25   switch (tfl_bop)
26   {
27 #define TFL_OPERATOR(OP)             \
28   case tflite::BuiltinOperator_##OP: \
29     return circle::BuiltinOperator_##OP;
30 #include "TFLOperator.lst"
31 #undef TFL_OPERATOR
32     default:
33       throw std::runtime_error("tflite2circle: wrong op");
34   }
35 }
36
37 circle::TensorType get_circle_tensortype(tflite::TensorType tfl_tt)
38 {
39   switch (tfl_tt)
40   {
41 #define TFL_TENSORTYPE(TENSORTYPE)      \
42   case tflite::TensorType_##TENSORTYPE: \
43     return circle::TensorType_##TENSORTYPE;
44 #include "TFLTensorType.lst"
45 #undef TFL_TENSORTYPE
46     default:
47       throw std::runtime_error("tflite2circle: wrong tensor type");
48   }
49 }
50
51 circle::Padding get_circle_padding(tflite::Padding tfl_p)
52 {
53   switch (tfl_p)
54   {
55     case tflite::Padding_SAME:
56       return circle::Padding_SAME;
57     case tflite::Padding_VALID:
58       return circle::Padding_VALID;
59     default:
60       throw std::runtime_error("tflite2circle: wrong padding");
61   }
62 }
63
64 circle::ActivationFunctionType
65 get_circle_activation_function_type(tflite::ActivationFunctionType tfl_aft)
66 {
67   switch (tfl_aft)
68   {
69 #define TFL_ACTIVATION_FUNCTION(TYPE)         \
70   case tflite::ActivationFunctionType_##TYPE: \
71     return circle::ActivationFunctionType_##TYPE;
72 #include "TFLActivationFunctionType.lst"
73 #undef TFL_ACTIVATION_FUNCTION
74     default:
75       throw std::runtime_error("tflite2circle: wrong activation function type.");
76   }
77 }
78
79 flatbuffers::Offset<void> get_circle_builtin_options(flatbuffers::FlatBufferBuilder &fb,
80                                                      const tflite::Operator *op)
81 {
82   auto tflite_builtin_options_type = op->builtin_options_type();
83   switch (tflite_builtin_options_type)
84   {
85     case tflite::BuiltinOptions_NONE:
86       return flatbuffers::Offset<void>();
87 #define TFL_BUILTIN_OPTIONS(TYPE)     \
88   case tflite::BuiltinOptions_##TYPE: \
89     return build_circle_##TYPE(fb, op).Union();
90 #include "TFLBuiltinOptions.lst"
91 #undef TFL_BUILTIN_OPTIONS
92     default:
93       throw std::runtime_error("tflite2circle: wrong builtin options type.");
94   }
95 }
96
97 circle::BuiltinOptions get_circle_builtin_options_type(const tflite::Operator *op)
98 {
99   switch (op->builtin_options_type())
100   {
101     case tflite::BuiltinOptions_NONE:
102       return circle::BuiltinOptions_NONE;
103 #define TFL_BUILTIN_OPTIONS(TYPE)     \
104   case tflite::BuiltinOptions_##TYPE: \
105     return circle::BuiltinOptions_##TYPE;
106 #include "TFLBuiltinOptions.lst"
107 #undef TFL_BUILTIN_OPTIONS
108     default:
109       throw std::runtime_error("tflite2circle: wrong builtin options type.");
110   }
111 }
112
113 circle::MirrorPadMode get_circle_mirrorpad_mode(tflite::MirrorPadMode tfl_mode)
114 {
115   switch (tfl_mode)
116   {
117     case tflite::MirrorPadMode_REFLECT:
118       return circle::MirrorPadMode_REFLECT;
119     case tflite::MirrorPadMode_SYMMETRIC:
120       return circle::MirrorPadMode_SYMMETRIC;
121     default:
122       throw std::runtime_error("tflite2circle: wrong mirrorpad mode.");
123   }
124 }
125
126 circle::DimensionType get_circle_dimension_type(tflite::DimensionType tfl_dim_type)
127 {
128   switch (tfl_dim_type)
129   {
130     case tflite::DimensionType_DENSE:
131       return circle::DimensionType_DENSE;
132     case tflite::DimensionType_SPARSE_CSR:
133       return circle::DimensionType_SPARSE_CSR;
134     default:
135       throw std::runtime_error("tflite2circle: wrong dimension type.");
136   }
137 }
138
139 flatbuffers::Offset<void>
140 get_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
141                                const tflite::DimensionMetadata *dm,
142                                const tflite::SparseIndexVector &tfl_sparse_index_vector_type)
143 {
144   switch (tfl_sparse_index_vector_type)
145   {
146     case tflite::SparseIndexVector_NONE:
147       return flatbuffers::Offset<void>();
148     case tflite::SparseIndexVector_Int32Vector:
149     {
150       auto values_vec_int32 =
151           std::vector<int32_t>{dm->array_segments_as_Int32Vector()->values()->begin(),
152                                dm->array_segments_as_Int32Vector()->values()->end()};
153       auto values_int32 = fb.CreateVector(values_vec_int32);
154       circle::Int32VectorBuilder int32_vector_builder{fb};
155       int32_vector_builder.add_values(values_int32);
156       return int32_vector_builder.Finish().Union();
157     }
158     case tflite::SparseIndexVector_Uint16Vector:
159     {
160       auto values_vec_uint16 =
161           std::vector<uint16_t>{dm->array_segments_as_Uint16Vector()->values()->begin(),
162                                 dm->array_segments_as_Uint16Vector()->values()->end()};
163       auto values_uint16 = fb.CreateVector(values_vec_uint16);
164       circle::Uint16VectorBuilder uint16_vector_builder{fb};
165       uint16_vector_builder.add_values(values_uint16);
166       return uint16_vector_builder.Finish().Union();
167     }
168     case tflite::SparseIndexVector_Uint8Vector:
169     {
170       auto values_vec_uint8 =
171           std::vector<uint8_t>{dm->array_segments_as_Uint8Vector()->values()->begin(),
172                                dm->array_segments_as_Uint8Vector()->values()->end()};
173       auto values_uint8 = fb.CreateVector(values_vec_uint8);
174       circle::Uint8VectorBuilder uint8_vector_builder{fb};
175       uint8_vector_builder.add_values(values_uint8);
176       return uint8_vector_builder.Finish().Union();
177     }
178     default:
179       throw std::runtime_error("tflite2circle: wrong SparseIndexVector type.");
180   }
181 }
182
183 circle::SparseIndexVector
184 get_circle_sparse_index_vector_type(const tflite::SparseIndexVector &tfl_sparse_index_vector_type)
185 {
186   switch (tfl_sparse_index_vector_type)
187   {
188     case tflite::SparseIndexVector_NONE:
189       return circle::SparseIndexVector_NONE;
190     case tflite::SparseIndexVector_Int32Vector:
191       return circle::SparseIndexVector_Int32Vector;
192     case tflite::SparseIndexVector_Uint16Vector:
193       return circle::SparseIndexVector_Uint16Vector;
194     case tflite::SparseIndexVector_Uint8Vector:
195       return circle::SparseIndexVector_Uint8Vector;
196     default:
197       throw std::runtime_error("tflite2circle: wrong SparseIndexVector type.");
198   }
199 }
200
201 } // namespace tflite2circle