Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / tflite2circle / src / CircleModel.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include <cassert>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

#include "CircleModel.h"
#include "DataLookup.h"
22
23 namespace tflite2circle
24 {
25
26 template <>
27 Offset<MetaDataBufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
28 {
29   if (tflite_flatbuffer_vec == nullptr)
30     return;
31   std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
32                                            tflite_flatbuffer_vec->end()};
33   _circle_flatbuffer_vec_offset = fb->CreateVector(metadata_buffer_vec);
34 }
35
36 template <>
37 Offset<BufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
38 {
39   std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
40
41   for (auto it : *tflite_flatbuffer_vec)
42   {
43     flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
44     if (it->data())
45     {
46       std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
47       buffer_data = fb->CreateVector(data_vec);
48     }
49     circle::BufferBuilder circle_buffer_builder{*fb};
50     circle_buffer_builder.add_data(buffer_data);
51     auto circle_buffers = circle_buffer_builder.Finish();
52     buffers_vec.emplace_back(circle_buffers);
53   }
54   _circle_flatbuffer_vec_offset = fb->CreateVector(buffers_vec);
55 }
56
57 template <>
58 Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
59 {
60   std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
61
62   for (auto it_sg : *tflite_flatbuffer_vec)
63   {
64     // tensors of subgraph
65     std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
66
67     auto tflite_tensors = it_sg->tensors();
68     for (auto it : *tflite_tensors)
69     {
70       // shape
71       flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
72       if (it->shape())
73       {
74         auto shape_vec = std::vector<int32_t>({it->shape()->begin(), it->shape()->end()});
75         shape = fb->CreateVector(shape_vec);
76       }
77       // name
78       flatbuffers::Offset<flatbuffers::String> name;
79       if (it->name())
80         name = fb->CreateString(it->name()->str());
81       // quantization
82       flatbuffers::Offset<circle::QuantizationParameters> quantization;
83       if (it->quantization())
84       {
85         std::vector<float> tfmin;
86         std::vector<float> tfmax;
87         std::vector<float> tfscale;
88         std::vector<int64_t> tfzerop;
89         flatbuffers::Offset<flatbuffers::Vector<float>> min;
90         flatbuffers::Offset<flatbuffers::Vector<float>> max;
91         flatbuffers::Offset<flatbuffers::Vector<float>> scale;
92         flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
93         int32_t quantized_dimension = it->quantization()->quantized_dimension();
94
95         if (it->quantization()->min() && it->quantization()->max())
96         {
97           auto rmin = it->quantization()->min();
98           auto rmax = it->quantization()->max();
99           tfmin = std::vector<float>{rmin->begin(), rmin->end()};
100           tfmax = std::vector<float>{rmax->begin(), rmax->end()};
101           min = fb->CreateVector(tfmin);
102           max = fb->CreateVector(tfmax);
103         }
104
105         if (it->quantization()->scale() && it->quantization()->zero_point())
106         {
107           auto rs = it->quantization()->scale();
108           auto rz = it->quantization()->zero_point();
109           tfscale = std::vector<float>{rs->begin(), rs->end()};
110           tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
111           scale = fb->CreateVector(tfscale);
112           zero_point = fb->CreateVector(tfzerop);
113         }
114
115         quantization = circle::CreateQuantizationParameters(*fb, min, max, scale, zero_point,
116                                                             circle::QuantizationDetails_NONE, 0,
117                                                             quantized_dimension);
118       }
119       // is_variable
120       bool is_variable = it->is_variable();
121
122       flatbuffers::Offset<circle::SparsityParameters> sparsity;
123       // sparsity
124       if (it->sparsity())
125       {
126         flatbuffers::Offset<flatbuffers::Vector<int32_t>> traversal_order;
127         flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_map;
128         flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<circle::DimensionMetadata>>>
129             dim_metadata;
130
131         // traversal_order
132         if (it->sparsity()->traversal_order())
133         {
134           auto traversal_order_vec = std::vector<int32_t>{
135               it->sparsity()->traversal_order()->begin(), it->sparsity()->traversal_order()->end()};
136           traversal_order = fb->CreateVector(traversal_order_vec);
137         }
138
139         // block_map
140         if (it->sparsity()->block_map())
141         {
142           auto block_map_vec = std::vector<int32_t>{it->sparsity()->block_map()->begin(),
143                                                     it->sparsity()->block_map()->end()};
144           block_map = fb->CreateVector(block_map_vec);
145         }
146
147         // dim_metadata
148         std::vector<flatbuffers::Offset<circle::DimensionMetadata>> dim_metadata_vec;
149         auto tflite_dim_metadata = it->sparsity()->dim_metadata();
150         for (auto it : *tflite_dim_metadata)
151         {
152           // array_segments
153           auto tflite_array_segments_type = it->array_segments_type();
154           auto circle_array_segments =
155               get_circle_sparse_index_vector(*fb, it, tflite_array_segments_type);
156           auto circle_array_segments_type =
157               get_circle_sparse_index_vector_type(tflite_array_segments_type);
158
159           // array_indices
160           auto tflite_array_indices_type = it->array_indices_type();
161           auto circle_array_indices =
162               get_circle_sparse_index_vector(*fb, it, tflite_array_indices_type);
163           auto circle_array_indices_type =
164               get_circle_sparse_index_vector_type(tflite_array_indices_type);
165
166           auto circle_dim_metadata_builder = circle::DimensionMetadataBuilder{*fb};
167
168           circle_dim_metadata_builder.add_format(get_circle_dimension_type(it->format()));
169           circle_dim_metadata_builder.add_dense_size(it->dense_size());
170           circle_dim_metadata_builder.add_array_segments(circle_array_segments);
171           circle_dim_metadata_builder.add_array_segments_type(circle_array_segments_type);
172           circle_dim_metadata_builder.add_array_indices(circle_array_indices);
173           circle_dim_metadata_builder.add_array_indices_type(circle_array_indices_type);
174           auto dim_metadata = circle_dim_metadata_builder.Finish();
175           dim_metadata_vec.emplace_back(dim_metadata);
176         }
177         dim_metadata = fb->CreateVector(dim_metadata_vec);
178
179         sparsity = circle::CreateSparsityParameters(*fb, traversal_order, block_map, dim_metadata);
180       }
181
182       // shape signature
183       flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape_signature;
184       if (it->shape_signature())
185       {
186         auto shape_signature_vec =
187             std::vector<int32_t>({it->shape_signature()->begin(), it->shape_signature()->end()});
188         shape_signature = fb->CreateVector(shape_signature_vec);
189       }
190
191       circle::TensorBuilder tensor_builder{*fb};
192       tensor_builder.add_shape(shape);
193       tensor_builder.add_type(get_circle_tensortype(it->type()));
194       tensor_builder.add_buffer(it->buffer());
195       tensor_builder.add_name(name);
196       tensor_builder.add_quantization(quantization);
197       tensor_builder.add_is_variable(is_variable);
198       tensor_builder.add_sparsity(sparsity);
199       tensor_builder.add_shape_signature(shape_signature);
200       auto tensor = tensor_builder.Finish();
201       tensor_vec.emplace_back(tensor);
202     }
203     auto circle_tensors = fb->CreateVector(tensor_vec);
204
205     // inputs of subgraph
206     auto tflite_inputs = it_sg->inputs();
207     std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
208
209     auto circle_inputs = fb->CreateVector(input_vec);
210
211     // outputs of subgraph
212     auto tflite_outputs = it_sg->outputs();
213     std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
214
215     auto circle_outputs = fb->CreateVector(output_vec);
216
217     // operators of subgraph
218     std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
219
220     auto tflite_operators = it_sg->operators();
221     for (auto it : *tflite_operators)
222     {
223       // inputs
224       std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
225       auto circle_inputs = fb->CreateVector(input_vec);
226       // outputs
227       std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
228       auto circle_outputs = fb->CreateVector(output_vec);
229       // builtin options
230       auto circle_builtin_options = get_circle_builtin_options(*fb, it);
231       auto circle_builtin_options_type = get_circle_builtin_options_type(it);
232       // custom options
233       flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
234       if (it->custom_options())
235       {
236         std::vector<uint8_t> custom_options_vec{it->custom_options()->begin(),
237                                                 it->custom_options()->end()};
238         circle_custom_options = fb->CreateVector(custom_options_vec);
239       }
240       // custom options format
241       // TODO Make get_circle_custom_options_format
242       assert(it->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS);
243       auto circle_custom_options_format = circle::CustomOptionsFormat_FLEXBUFFERS;
244
245       circle::OperatorBuilder operator_builder{*fb};
246       operator_builder.add_opcode_index(it->opcode_index());
247       operator_builder.add_inputs(circle_inputs);
248       operator_builder.add_outputs(circle_outputs);
249       operator_builder.add_builtin_options(circle_builtin_options);
250       operator_builder.add_builtin_options_type(circle_builtin_options_type);
251       operator_builder.add_custom_options(circle_custom_options);
252       operator_builder.add_custom_options_format(circle_custom_options_format);
253       // TODO mutating_variable_inputs
254       auto opeartor = operator_builder.Finish();
255       operator_vec.emplace_back(opeartor);
256     }
257     auto circle_operators = fb->CreateVector(operator_vec);
258
259     // name of subgraph
260     auto subgraphs_name = fb->CreateString(it_sg->name());
261
262     // subgraphs
263     auto circle_subgraph_builder = circle::SubGraphBuilder{*fb};
264
265     circle_subgraph_builder.add_tensors(circle_tensors);
266     circle_subgraph_builder.add_inputs(circle_inputs);
267     circle_subgraph_builder.add_outputs(circle_outputs);
268     circle_subgraph_builder.add_operators(circle_operators);
269     circle_subgraph_builder.add_name(subgraphs_name);
270     circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
271
272     auto circle_subgraph = circle_subgraph_builder.Finish();
273     subgprahs_vec.emplace_back(circle_subgraph);
274   }
275   _circle_flatbuffer_vec_offset = fb->CreateVector(subgprahs_vec);
276 }
277
278 template <>
279 Offset<OperatorCodeLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
280 {
281   std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
282
283   for (auto it : *tflite_flatbuffer_vec)
284   {
285     auto custom_code = fb->CreateString(it->custom_code());
286     circle::OperatorCodeBuilder operator_code_builder{*fb};
287     operator_code_builder.add_builtin_code(get_circle_builtin_code(it->builtin_code()));
288     operator_code_builder.add_custom_code(custom_code);
289     operator_code_builder.add_version(it->version());
290     auto code = operator_code_builder.Finish();
291     operator_code_vec.emplace_back(code);
292   }
293   _circle_flatbuffer_vec_offset = fb->CreateVector(operator_code_vec);
294 }
295
296 CircleModel::CircleModel(FlatBufBuilder &fb, TFLModel &model)
297     : _version{0}, _description{fb->CreateString("nnpackage")}, _fb{fb}
298 {
299   const tflite::Model *tfl_model = model.load_model();
300   // verify flatbuffers
301   flatbuffers::Verifier verifier{reinterpret_cast<const uint8_t *>(model._data.data()),
302                                  model._data.size()};
303   if (!tflite::VerifyModelBuffer(verifier))
304   {
305     throw std::runtime_error("ERROR: Failed to verify tflite");
306   }
307
308   _operator_codes_offset =
309       std::make_unique<Offset<OperatorCodeLink>>(fb, tfl_model->operator_codes());
310   _subGraphs_offset = std::make_unique<Offset<SubGraphLink>>(fb, tfl_model->subgraphs());
311   _buffers_offset = std::make_unique<Offset<BufferLink>>(fb, tfl_model->buffers());
312   _metadata_buffer_offset =
313       std::make_unique<Offset<MetaDataBufferLink>>(fb, tfl_model->metadata_buffer());
314   model_build();
315 }
316
317 void CircleModel::model_build(void) const
318 {
319   circle::ModelBuilder model_builder{*_fb};
320
321   model_builder.add_version(_version);
322   model_builder.add_description(_description);
323   model_builder.add_operator_codes(_operator_codes_offset->offset());
324   model_builder.add_subgraphs(_subGraphs_offset->offset());
325   model_builder.add_buffers(_buffers_offset->offset());
326   model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
327
328   auto model = model_builder.Finish();
329   circle::FinishModelBuffer(*_fb, model);
330 }
331
332 const char *CircleModel::base(void) const
333 {
334   return reinterpret_cast<const char *>(_fb->GetBufferPointer());
335 }
336
337 size_t CircleModel::size(void) const { return _fb->GetSize(); }
338
339 } // namespace tflite2circle