/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "CircleModel.h"
#include "DataLookup.h"

#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>
23 namespace tflite2circle
27 Offset<MetaDataBufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
29 if (tflite_flatbuffer_vec == nullptr)
31 std::vector<int32_t> metadata_buffer_vec{tflite_flatbuffer_vec->begin(),
32 tflite_flatbuffer_vec->end()};
33 _circle_flatbuffer_vec_offset = fb->CreateVector(metadata_buffer_vec);
37 Offset<BufferLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
39 std::vector<flatbuffers::Offset<circle::Buffer>> buffers_vec;
41 for (auto it : *tflite_flatbuffer_vec)
43 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> buffer_data;
46 std::vector<uint8_t> data_vec{it->data()->begin(), it->data()->end()};
47 buffer_data = fb->CreateVector(data_vec);
49 circle::BufferBuilder circle_buffer_builder{*fb};
50 circle_buffer_builder.add_data(buffer_data);
51 auto circle_buffers = circle_buffer_builder.Finish();
52 buffers_vec.emplace_back(circle_buffers);
54 _circle_flatbuffer_vec_offset = fb->CreateVector(buffers_vec);
58 Offset<SubGraphLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
60 std::vector<flatbuffers::Offset<circle::SubGraph>> subgprahs_vec;
62 for (auto it_sg : *tflite_flatbuffer_vec)
64 // tensors of subgraph
65 std::vector<flatbuffers::Offset<circle::Tensor>> tensor_vec;
67 auto tflite_tensors = it_sg->tensors();
68 for (auto it : *tflite_tensors)
71 flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape;
74 auto shape_vec = std::vector<int32_t>({it->shape()->begin(), it->shape()->end()});
75 shape = fb->CreateVector(shape_vec);
78 flatbuffers::Offset<flatbuffers::String> name;
80 name = fb->CreateString(it->name()->str());
82 flatbuffers::Offset<circle::QuantizationParameters> quantization;
83 if (it->quantization())
85 std::vector<float> tfmin;
86 std::vector<float> tfmax;
87 std::vector<float> tfscale;
88 std::vector<int64_t> tfzerop;
89 flatbuffers::Offset<flatbuffers::Vector<float>> min;
90 flatbuffers::Offset<flatbuffers::Vector<float>> max;
91 flatbuffers::Offset<flatbuffers::Vector<float>> scale;
92 flatbuffers::Offset<flatbuffers::Vector<int64_t>> zero_point;
93 int32_t quantized_dimension = it->quantization()->quantized_dimension();
95 if (it->quantization()->min() && it->quantization()->max())
97 auto rmin = it->quantization()->min();
98 auto rmax = it->quantization()->max();
99 tfmin = std::vector<float>{rmin->begin(), rmin->end()};
100 tfmax = std::vector<float>{rmax->begin(), rmax->end()};
101 min = fb->CreateVector(tfmin);
102 max = fb->CreateVector(tfmax);
105 if (it->quantization()->scale() && it->quantization()->zero_point())
107 auto rs = it->quantization()->scale();
108 auto rz = it->quantization()->zero_point();
109 tfscale = std::vector<float>{rs->begin(), rs->end()};
110 tfzerop = std::vector<int64_t>{rz->begin(), rz->end()};
111 scale = fb->CreateVector(tfscale);
112 zero_point = fb->CreateVector(tfzerop);
115 quantization = circle::CreateQuantizationParameters(*fb, min, max, scale, zero_point,
116 circle::QuantizationDetails_NONE, 0,
117 quantized_dimension);
120 bool is_variable = it->is_variable();
122 circle::TensorBuilder tensor_builder{*fb};
123 tensor_builder.add_shape(shape);
124 tensor_builder.add_type(get_circle_tensortype(it->type()));
125 tensor_builder.add_buffer(it->buffer());
126 tensor_builder.add_name(name);
127 tensor_builder.add_quantization(quantization);
128 tensor_builder.add_is_variable(is_variable);
129 auto tensor = tensor_builder.Finish();
130 tensor_vec.emplace_back(tensor);
132 auto circle_tensors = fb->CreateVector(tensor_vec);
134 // inputs of subgraph
135 auto tflite_inputs = it_sg->inputs();
136 std::vector<int32_t> input_vec{tflite_inputs->begin(), tflite_inputs->end()};
138 auto circle_inputs = fb->CreateVector(input_vec);
140 // outputs of subgraph
141 auto tflite_outputs = it_sg->outputs();
142 std::vector<int32_t> output_vec{tflite_outputs->begin(), tflite_outputs->end()};
144 auto circle_outputs = fb->CreateVector(output_vec);
146 // operators of subgraph
147 std::vector<flatbuffers::Offset<circle::Operator>> operator_vec;
149 auto tflite_operators = it_sg->operators();
150 for (auto it : *tflite_operators)
153 std::vector<int32_t> input_vec{it->inputs()->begin(), it->inputs()->end()};
154 auto circle_inputs = fb->CreateVector(input_vec);
156 std::vector<int32_t> output_vec{it->outputs()->begin(), it->outputs()->end()};
157 auto circle_outputs = fb->CreateVector(output_vec);
159 auto circle_builtin_options = get_circle_builtin_options(*fb, it);
160 auto circle_builtin_options_type = get_circle_builtin_options_type(it);
162 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> circle_custom_options;
163 if (it->custom_options())
165 std::vector<uint8_t> custom_options_vec{it->custom_options()->begin(),
166 it->custom_options()->end()};
167 circle_custom_options = fb->CreateVector(custom_options_vec);
169 // custom options format
170 // TODO Make get_circle_custom_options_format
171 assert(it->custom_options_format() == tflite::CustomOptionsFormat_FLEXBUFFERS);
172 auto circle_custom_options_format = circle::CustomOptionsFormat_FLEXBUFFERS;
174 circle::OperatorBuilder operator_builder{*fb};
175 operator_builder.add_opcode_index(it->opcode_index());
176 operator_builder.add_inputs(circle_inputs);
177 operator_builder.add_outputs(circle_outputs);
178 operator_builder.add_builtin_options(circle_builtin_options);
179 operator_builder.add_builtin_options_type(circle_builtin_options_type);
180 operator_builder.add_custom_options(circle_custom_options);
181 operator_builder.add_custom_options_format(circle_custom_options_format);
182 // TODO mutating_variable_inputs
183 auto opeartor = operator_builder.Finish();
184 operator_vec.emplace_back(opeartor);
186 auto circle_operators = fb->CreateVector(operator_vec);
189 auto subgraphs_name = fb->CreateString(it_sg->name());
192 auto circle_subgraph_builder = circle::SubGraphBuilder{*fb};
194 circle_subgraph_builder.add_tensors(circle_tensors);
195 circle_subgraph_builder.add_inputs(circle_inputs);
196 circle_subgraph_builder.add_outputs(circle_outputs);
197 circle_subgraph_builder.add_operators(circle_operators);
198 circle_subgraph_builder.add_name(subgraphs_name);
199 circle_subgraph_builder.add_data_format(circle::DataFormat_CHANNELS_LAST);
201 auto circle_subgraph = circle_subgraph_builder.Finish();
202 subgprahs_vec.emplace_back(circle_subgraph);
204 _circle_flatbuffer_vec_offset = fb->CreateVector(subgprahs_vec);
208 Offset<OperatorCodeLink>::Offset(FlatBufBuilder &fb, const TFLFlatBufVec *tflite_flatbuffer_vec)
210 std::vector<flatbuffers::Offset<circle::OperatorCode>> operator_code_vec;
212 for (auto it : *tflite_flatbuffer_vec)
214 auto custom_code = fb->CreateString(it->custom_code());
215 circle::OperatorCodeBuilder operator_code_builder{*fb};
216 operator_code_builder.add_builtin_code(get_circle_builtin_code(it->builtin_code()));
217 operator_code_builder.add_custom_code(custom_code);
218 operator_code_builder.add_version(it->version());
219 auto code = operator_code_builder.Finish();
220 operator_code_vec.emplace_back(code);
222 _circle_flatbuffer_vec_offset = fb->CreateVector(operator_code_vec);
225 CircleModel::CircleModel(FlatBufBuilder &fb, TFLModel &model)
226 : _version{0}, _description{fb->CreateString("nnpackage")}, _fb{fb}
228 const tflite::Model *tfl_model = model.load_model();
229 _operator_codes_offset =
230 std::make_unique<Offset<OperatorCodeLink>>(fb, tfl_model->operator_codes());
231 _subGraphs_offset = std::make_unique<Offset<SubGraphLink>>(fb, tfl_model->subgraphs());
232 _buffers_offset = std::make_unique<Offset<BufferLink>>(fb, tfl_model->buffers());
233 _metadata_buffer_offset =
234 std::make_unique<Offset<MetaDataBufferLink>>(fb, tfl_model->metadata_buffer());
238 void CircleModel::model_build(void) const
240 circle::ModelBuilder model_builder{*_fb};
242 model_builder.add_version(_version);
243 model_builder.add_description(_description);
244 model_builder.add_operator_codes(_operator_codes_offset->offset());
245 model_builder.add_subgraphs(_subGraphs_offset->offset());
246 model_builder.add_buffers(_buffers_offset->offset());
247 model_builder.add_metadata_buffer(_metadata_buffer_offset->offset());
249 auto model = model_builder.Finish();
250 circle::FinishModelBuffer(*_fb, model);
253 const char *CircleModel::base(void) const
255 return reinterpret_cast<const char *>(_fb->GetBufferPointer());
258 size_t CircleModel::size(void) const { return _fb->GetSize(); }
260 } // namespace tflite2circle