cb4437a49553d9961feae3e015db0c5db1009ae5
[platform/core/ml/nnfw.git] / compiler / tflite2circle / src / CircleModel.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <iostream>
18 #include <memory>
19
20 #include "CircleModel.h"
21 #include "DataLookup.h"
22
23 namespace tflite2circle
24 {
25
26 template <>
27 Offset<MetaDataBufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
28 {
29   if (tflite_flatbuffer_vec == nullptr)
30     return;
31   std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
32                                            tflite_flatbuffer_vec->end()};
33   _circle_flatbuffer_vec_offset = fb->CreateVector(metadata_buffer_vec);
34 }
35
36 template <>
37 Offset<BufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
38 {
39   std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
40
41   for (auto it : *tflite_flatbuffer_vec)
42   {
43     flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
44     if (it->data())
45     {
46       std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
47       buffer_data = fb->CreateVector(data_vec);
48     }
49     circle::BufferBuilder circle_buffer_builder{*fb};
50     circle_buffer_builder.add_data(buffer_data);
51     auto circle_buffers = circle_buffer_builder.Finish();
52     buffers_vec.emplace_back(circle_buffers);
53   }
54   _circle_flatbuffer_vec_offset = fb->CreateVector(buffers_vec);
55 }
56
57 template <>
58 Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
59 {
60   std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
61
62   for (auto it_sg : *tflite_flatbuffer_vec)
63   {
64     // tensors of subgraph
65     std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
66
67     auto tflite_tensors = it_sg->tensors();
68     for (auto it : *tflite_tensors)
69     {
70       // shape
71       flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
72       if (it->shape())
73       {
74         auto shape_vec = std::vector<int32_t>({it->shape()->begin(), it->shape()->end()});
75         shape = fb->CreateVector(shape_vec);
76       }
77       // name
78       flatbuffers::Offset<flatbuffers::String> name;
79       if (it->name())
80         name = fb->CreateString(it->name()->str());
81       // quantization
82       flatbuffers::Offset<circle::QuantizationParameters> quantization;
83       if (it->quantization())
84       {
85         std::vector<float> tfmin;
86         std::vector<float> tfmax;
87         std::vector<float> tfscale;
88         std::vector<int64_t> tfzerop;
89         flatbuffers::Offset<flatbuffers::Vector<float>> min;
90         flatbuffers::Offset<flatbuffers::Vector<float>> max;
91         flatbuffers::Offset<flatbuffers::Vector<float>> scale;
92         flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
93         int32_t quantized_dimension = it->quantization()->quantized_dimension();
94
95         if (it->quantization()->min() && it->quantization()->max())
96         {
97           auto rmin = it->quantization()->min();
98           auto rmax = it->quantization()->max();
99           tfmin = std::vector<float>{rmin->begin(), rmin->end()};
100           tfmax = std::vector<float>{rmax->begin(), rmax->end()};
101           min = fb->CreateVector(tfmin);
102           max = fb->CreateVector(tfmax);
103         }
104
105         if (it->quantization()->scale() && it->quantization()->zero_point())
106         {
107           auto rs = it->quantization()->scale();
108           auto rz = it->quantization()->zero_point();
109           tfscale = std::vector<float>{rs->begin(), rs->end()};
110           tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
111           scale = fb->CreateVector(tfscale);
112           zero_point = fb->CreateVector(tfzerop);
113         }
114
115         quantization = circle::CreateQuantizationParameters(*fb, min, max, scale, zero_point,
116                                                             circle::QuantizationDetails_NONE, 0,
117                                                             quantized_dimension);
118       }
119       // is_variable
120       bool is_variable = it->is_variable();
121
122       circle::TensorBuilder tensor_builder{*fb};
123       tensor_builder.add_shape(shape);
124       tensor_builder.add_type(get_circle_tensortype(it->type()));
125       tensor_builder.add_buffer(it->buffer());
126       tensor_builder.add_name(name);
127       tensor_builder.add_quantization(quantization);
128       tensor_builder.add_is_variable(is_variable);
129       auto tensor = tensor_builder.Finish();
130       tensor_vec.emplace_back(tensor);
131     }
132     auto circle_tensors = fb->CreateVector(tensor_vec);
133
134     // inputs of subgraph
135     auto tflite_inputs = it_sg->inputs();
136     std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
137
138     auto circle_inputs = fb->CreateVector(input_vec);
139
140     // outputs of subgraph
141     auto tflite_outputs = it_sg->outputs();
142     std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
143
144     auto circle_outputs = fb->CreateVector(output_vec);
145
146     // operators of subgraph
147     std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
148
149     auto tflite_operators = it_sg->operators();
150     for (auto it : *tflite_operators)
151     {
152       // inputs
153       std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
154       auto circle_inputs = fb->CreateVector(input_vec);
155       // outputs
156       std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
157       auto circle_outputs = fb->CreateVector(output_vec);
158       // builtin options
159       auto circle_builtin_options = get_circle_builtin_options(*fb, it);
160       auto circle_builtin_options_type = get_circle_builtin_options_type(it);
161       // custom options
162       flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
163       if (it->custom_options())
164       {
165         std::vector<uint8_t> custom_options_vec{it->custom_options()->begin(),
166                                                 it->custom_options()->end()};
167         circle_custom_options = fb->CreateVector(custom_options_vec);
168       }
169       // custom options format
170       // TODO Make get_circle_custom_options_format
171       assert(it->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS);
172       auto circle_custom_options_format = circle::CustomOptionsFormat_FLEXBUFFERS;
173
174       circle::OperatorBuilder operator_builder{*fb};
175       operator_builder.add_opcode_index(it->opcode_index());
176       operator_builder.add_inputs(circle_inputs);
177       operator_builder.add_outputs(circle_outputs);
178       operator_builder.add_builtin_options(circle_builtin_options);
179       operator_builder.add_builtin_options_type(circle_builtin_options_type);
180       operator_builder.add_custom_options(circle_custom_options);
181       operator_builder.add_custom_options_format(circle_custom_options_format);
182       // TODO mutating_variable_inputs
183       auto opeartor = operator_builder.Finish();
184       operator_vec.emplace_back(opeartor);
185     }
186     auto circle_operators = fb->CreateVector(operator_vec);
187
188     // name of subgraph
189     auto subgraphs_name = fb->CreateString(it_sg->name());
190
191     // subgraphs
192     auto circle_subgraph_builder = circle::SubGraphBuilder{*fb};
193
194     circle_subgraph_builder.add_tensors(circle_tensors);
195     circle_subgraph_builder.add_inputs(circle_inputs);
196     circle_subgraph_builder.add_outputs(circle_outputs);
197     circle_subgraph_builder.add_operators(circle_operators);
198     circle_subgraph_builder.add_name(subgraphs_name);
199     circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
200
201     auto circle_subgraph = circle_subgraph_builder.Finish();
202     subgprahs_vec.emplace_back(circle_subgraph);
203   }
204   _circle_flatbuffer_vec_offset = fb->CreateVector(subgprahs_vec);
205 }
206
207 template <>
208 Offset<OperatorCodeLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
209 {
210   std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
211
212   for (auto it : *tflite_flatbuffer_vec)
213   {
214     auto custom_code = fb->CreateString(it->custom_code());
215     circle::OperatorCodeBuilder operator_code_builder{*fb};
216     operator_code_builder.add_builtin_code(get_circle_builtin_code(it->builtin_code()));
217     operator_code_builder.add_custom_code(custom_code);
218     operator_code_builder.add_version(it->version());
219     auto code = operator_code_builder.Finish();
220     operator_code_vec.emplace_back(code);
221   }
222   _circle_flatbuffer_vec_offset = fb->CreateVector(operator_code_vec);
223 }
224
225 CircleModel::CircleModel(FlatBufBuilder &fb, TFLModel &model)
226     : _version{0}, _description{fb->CreateString("nnpackage")}, _fb{fb}
227 {
228   const tflite::Model *tfl_model = model.load_model();
229   _operator_codes_offset =
230       std::make_unique<Offset<OperatorCodeLink>>(fb, tfl_model->operator_codes());
231   _subGraphs_offset = std::make_unique<Offset<SubGraphLink>>(fb, tfl_model->subgraphs());
232   _buffers_offset = std::make_unique<Offset<BufferLink>>(fb, tfl_model->buffers());
233   _metadata_buffer_offset =
234       std::make_unique<Offset<MetaDataBufferLink>>(fb, tfl_model->metadata_buffer());
235   model_build();
236 }
237
238 void CircleModel::model_build(void) const
239 {
240   circle::ModelBuilder model_builder{*_fb};
241
242   model_builder.add_version(_version);
243   model_builder.add_description(_description);
244   model_builder.add_operator_codes(_operator_codes_offset->offset());
245   model_builder.add_subgraphs(_subGraphs_offset->offset());
246   model_builder.add_buffers(_buffers_offset->offset());
247   model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
248
249   auto model = model_builder.Finish();
250   circle::FinishModelBuffer(*_fb, model);
251 }
252
253 const char *CircleModel::base(void) const
254 {
255   return reinterpret_cast<const char *>(_fb->GetBufferPointer());
256 }
257
258 size_t CircleModel::size(void) const { return _fb->GetSize(); }
259
260 } // namespace tflite2circle