2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <tflchef/RecipeChef.h>
18 #include <mio_tflite2121/Helper.h>
21 #include "TFliteImport.h"
22 #include "TFliteOpChef.h"
23 #include "TFliteOpChefs.h"
24 #include "TFliteOpRegistry.h"
32 void set_inputs(TFliteImport *import, tflchef::Operation *operation, const tflite::Operator *op)
34 auto tensors = import->tensors();
35 const std::vector<int32_t> &inputs = as_index_vector(op->inputs());
37 for (auto input : inputs)
41 operation->add_input("");
45 auto tensor = tensors->Get(input);
46 std::string name = mio::tflite::tensor_name(tensor);
47 operation->add_input(name);
52 void set_outputs(TFliteImport *import, tflchef::Operation *operation, const tflite::Operator *op)
54 auto tensors = import->tensors();
55 const std::vector<int32_t> &outputs = as_index_vector(op->outputs());
57 for (auto output : outputs)
59 auto tensor = tensors->Get(output);
60 std::string name = mio::tflite::tensor_name(tensor);
61 operation->add_output(name);
66 * @brief This will build ModelRecipe from tflite::Model
67 * First to check operand filler options by scanning all operators,
68 * then translate all operands and operators.
69 * Last will set network inputs and outputs.
71 std::unique_ptr<ModelRecipe> generate_recipe(const tflite::Model *model)
73 std::unique_ptr<ModelRecipe> model_recipe{new ModelRecipe()};
75 TFliteImport tflite_import(model);
77 assert(tflite_import.num_subgraph() == 1);
78 tflite_import.select_sub_graph(0);
80 auto tensors = tflite_import.tensors();
81 auto buffers = tflite_import.buffers();
82 auto operators = tflite_import.operators();
84 // operand fillers for adding all operators
85 for (uint32_t i = 0; i < operators->Length(); ++i)
87 const auto *op = operators->Get(i);
88 tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
90 if (const auto *graph_builder = TFliteOpRegistry::get().lookup(builtincode))
92 graph_builder->filler(op, &tflite_import, model_recipe.get());
96 std::string opcodename = tflite_import.opcode_name(op);
97 throw std::runtime_error{"Not supported: " + opcodename};
101 // add all operands(tensors)
102 for (uint32_t i = 0; i < tensors->Length(); ++i)
104 auto tensor = tensors->Get(i);
107 if (tensor->buffer() >= buffers->size())
108 throw std::runtime_error{"file load failed"};
110 ::tflchef::Operand *operand = model_recipe->add_operand();
112 operand->set_name(mio::tflite::tensor_name(tensor));
113 operand->set_type(as_tflchef_type(tensor->type()));
114 operand->set_is_variable(tensor->is_variable());
118 std::vector<int32_t> dims = as_index_vector(tensor->shape());
119 ::tflchef::TensorShape *shape = operand->mutable_shape();
120 for (auto dim : dims)
126 // filler for weights, bias and so on
127 std::vector<int32_t> expvalues;
128 std::vector<float> expfvalues;
129 if (tflite_import.get_tensor_filler(i))
131 tflchef::TensorFiller *filler = operand->mutable_filler();
132 // Note: it is OK to use random weights for functionality validation
133 filler->set_tag("gaussian");
134 filler->add_arg("0.0"); // average
135 filler->add_arg("0.1"); // standard deviation
137 else if (tflite_import.get_tensor_filler(i, expvalues))
139 tflchef::TensorFiller *filler = operand->mutable_filler();
140 filler->set_tag("explicit");
141 for (auto value : expvalues)
143 std::ostringstream ss;
145 filler->add_arg(ss.str());
148 else if (tflite_import.get_tensor_filler(i, expfvalues))
150 tflchef::TensorFiller *filler = operand->mutable_filler();
151 filler->set_tag("explicit");
152 for (auto value : expfvalues)
154 std::ostringstream ss;
156 filler->add_arg(ss.str());
160 auto quant = tensor->quantization();
161 if (quant != nullptr)
163 // Note: Calling 'operand->mutable_quant()' will create empty 'quant' node
164 // in the recipe file. We want this only when valid parameter exist.
165 if (quant->min() != nullptr && quant->min()->size() > 0)
167 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
168 for (uint32_t idx = 0; idx < quant->min()->size(); ++idx)
169 chef_quant->add_min(quant->min()->Get(idx));
171 if (quant->max() != nullptr && quant->max()->size() > 0)
173 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
174 for (uint32_t idx = 0; idx < quant->max()->size(); idx++)
175 chef_quant->add_max(quant->max()->Get(idx));
177 if (quant->scale() != nullptr && quant->scale()->size() > 0)
179 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
180 for (uint32_t idx = 0; idx < quant->scale()->size(); ++idx)
181 chef_quant->add_scale(quant->scale()->Get(idx));
183 if (quant->zero_point() != nullptr && quant->zero_point()->size() > 0)
185 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
186 for (uint32_t idx = 0; idx < quant->zero_point()->size(); ++idx)
187 chef_quant->add_zero_point(quant->zero_point()->Get(idx));
189 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
190 chef_quant->set_quantized_dimension(quant->quantized_dimension());
193 auto sparsity = tensor->sparsity();
194 if (sparsity != nullptr)
196 tflchef::TensorSparsity *chef_sparsity = operand->mutable_sparsity();
198 auto chef_traversal_order = chef_sparsity->mutable_traversal_order();
199 for (const auto &to : *(sparsity->traversal_order()))
201 chef_traversal_order->add_dim(to);
204 auto chef_block_map = chef_sparsity->mutable_block_map();
205 for (const auto &bm : *(sparsity->block_map()))
207 chef_block_map->add_dim(bm);
210 for (const auto &dm : *(sparsity->dim_metadata()))
212 auto chef_dm = chef_sparsity->add_dim_metadata();
214 chef_dm->set_format(as_tflchef_sparse_dim_type(dm->format()));
216 chef_dm->set_dense_size(dm->dense_size());
218 auto chef_array_segments = chef_dm->mutable_array_segments();
219 switch (dm->array_segments_type())
221 case tflite::SparseIndexVector_NONE:
224 case tflite::SparseIndexVector_Int32Vector:
225 for (const auto &as : *(dm->array_segments_as_Int32Vector()->values()))
227 chef_array_segments->add_dim(as);
230 case tflite::SparseIndexVector_Uint16Vector:
231 for (const auto &as : *(dm->array_segments_as_Uint16Vector()->values()))
233 chef_array_segments->add_dim(as);
236 case tflite::SparseIndexVector_Uint8Vector:
237 for (const auto &as : *(dm->array_segments_as_Uint8Vector()->values()))
239 chef_array_segments->add_dim(as);
243 throw std::runtime_error("unsupported sparse index vector type");
246 auto chef_array_indices = chef_dm->mutable_array_indices();
247 switch (dm->array_indices_type())
249 case tflite::SparseIndexVector_NONE:
252 case tflite::SparseIndexVector_Int32Vector:
253 for (const auto &as : *(dm->array_indices_as_Int32Vector()->values()))
255 chef_array_indices->add_dim(as);
258 case tflite::SparseIndexVector_Uint16Vector:
259 for (const auto &as : *(dm->array_indices_as_Uint16Vector()->values()))
261 chef_array_indices->add_dim(as);
264 case tflite::SparseIndexVector_Uint8Vector:
265 for (const auto &as : *(dm->array_indices_as_Uint8Vector()->values()))
267 chef_array_indices->add_dim(as);
271 throw std::runtime_error("unsupported sparse index vector type");
276 auto shape_signature = tensor->shape_signature();
277 if (shape_signature != nullptr)
279 tflchef::ShapeSignature *chef_shape_signature = operand->mutable_shape_signature();
280 for (uint32_t i = 0; i < shape_signature->size(); ++i)
282 chef_shape_signature->add_dim(shape_signature->Get(i));
288 for (uint32_t i = 0; i < operators->Length(); ++i)
290 const auto *op = operators->Get(i);
291 tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
293 if (const auto *graph_builder = TFliteOpRegistry::get().lookup(builtincode))
295 auto operation = graph_builder->build(op, &tflite_import, model_recipe.get());
297 // common for all operators: inputs, outputs
298 set_inputs(&tflite_import, operation, op);
299 set_outputs(&tflite_import, operation, op);
303 std::string opcodename = tflite_import.opcode_name(op);
304 throw std::runtime_error{"Not supported: " + opcodename};
308 // network inputs/outputs
309 const std::vector<int32_t> &inputs = tflite_import.inputs();
310 const std::vector<int32_t> &outputs = tflite_import.outputs();
312 for (const auto input : inputs)
314 auto tensor = tensors->Get(input);
315 std::string name = mio::tflite::tensor_name(tensor);
317 model_recipe->add_input(name);
319 for (const auto output : outputs)
321 auto tensor = tensors->Get(output);
322 std::string name = mio::tflite::tensor_name(tensor);
324 model_recipe->add_output(name);
327 return std::move(model_recipe);
330 bool write_recipe(const std::string &filename, std::unique_ptr<ModelRecipe> &recipe)
332 std::fstream fo(filename, std::ios::binary | std::ios::out);
336 throw std::runtime_error{"file store failed"};
339 // Note: SerializeToString() or SerializeToOstream() writes in binary mode
340 // DebugString() and Utf8DebugString() will print as a human readable text
341 fo << recipe->Utf8DebugString();
348 } // namespace tflchef