/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <tflchef/RecipeChef.h>

#include "TFliteImport.h"
#include "TFliteOpChef.h"
#include "TFliteOpChefs.h"
#include "TFliteOpRegistry.h"

#include <cassert>
#include <cstdint>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
31 void set_inputs(TFliteImport *import, tflchef::Operation *operation, const tflite::Operator *op)
33 auto tensors = import->tensors();
34 const std::vector<int32_t> &inputs = as_index_vector(op->inputs());
36 for (auto input : inputs)
40 operation->add_input("");
44 auto tensor = tensors->Get(input);
45 std::string name = tensor_name(tensor);
46 operation->add_input(name);
51 void set_outputs(TFliteImport *import, tflchef::Operation *operation, const tflite::Operator *op)
53 auto tensors = import->tensors();
54 const std::vector<int32_t> &outputs = as_index_vector(op->outputs());
56 for (auto output : outputs)
58 auto tensor = tensors->Get(output);
59 std::string name = tensor_name(tensor);
60 operation->add_output(name);
65 * @brief This will build ModelRecipe from tflite::Model
66 * First to check operand filler options by scanning all operators,
67 * then translate all operands and operators.
68 * Last will set network inputs and outputs.
70 std::unique_ptr<ModelRecipe> generate_recipe(const tflite::Model *model)
72 std::unique_ptr<ModelRecipe> model_recipe{new ModelRecipe()};
74 TFliteImport tflite_import(model);
76 assert(tflite_import.num_subgraph() == 1);
77 tflite_import.select_sub_graph(0);
79 auto tensors = tflite_import.tensors();
80 auto buffers = tflite_import.buffers();
81 auto operators = tflite_import.operators();
83 // operand fillers for adding all operators
84 for (uint32_t i = 0; i < operators->Length(); ++i)
86 const auto *op = operators->Get(i);
87 tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
89 if (const auto *graph_builder = TFliteOpRegistry::get().lookup(builtincode))
91 graph_builder->filler(op, &tflite_import, model_recipe.get());
95 std::string opcodename = tflite_import.opcode_name(op);
96 throw std::runtime_error{"Not supported: " + opcodename};
100 // add all operands(tensors)
101 for (uint32_t i = 0; i < tensors->Length(); ++i)
103 auto tensor = tensors->Get(i);
106 if (tensor->buffer() >= buffers->size())
107 throw std::runtime_error{"file load failed"};
109 ::tflchef::Operand *operand = model_recipe->add_operand();
111 operand->set_name(tensor_name(tensor));
112 operand->set_type(as_tflchef_type(tensor->type()));
116 std::vector<int32_t> dims = as_index_vector(tensor->shape());
117 ::tflchef::TensorShape *shape = operand->mutable_shape();
118 for (auto dim : dims)
124 // filler for weights, bias and so on
125 std::vector<int32_t> expvalues;
126 std::vector<float> expfvalues;
127 if (tflite_import.get_tensor_filler(i))
129 tflchef::TensorFiller *filler = operand->mutable_filler();
130 // Note: it is OK to use random weights for functionality validation
131 filler->set_tag("gaussian");
132 filler->add_arg("0.0"); // average
133 filler->add_arg("0.1"); // standard deviation
135 else if (tflite_import.get_tensor_filler(i, expvalues))
137 tflchef::TensorFiller *filler = operand->mutable_filler();
138 filler->set_tag("explicit");
139 for (auto value : expvalues)
141 std::ostringstream ss;
143 filler->add_arg(ss.str());
146 else if (tflite_import.get_tensor_filler(i, expfvalues))
148 tflchef::TensorFiller *filler = operand->mutable_filler();
149 filler->set_tag("explicit");
150 for (auto value : expfvalues)
152 std::ostringstream ss;
154 filler->add_arg(ss.str());
158 auto quant = tensor->quantization();
159 if (quant != nullptr)
161 // Note: Calling 'operand->mutable_quant()' will create empty 'quant' node
162 // in the recipe file. We want this only when valid parameter exist.
163 if (quant->min() != nullptr && quant->min()->size() > 0)
165 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
166 for (uint32_t idx = 0; idx < quant->min()->size(); ++idx)
167 chef_quant->add_min(quant->min()->Get(idx));
169 if (quant->max() != nullptr && quant->max()->size() > 0)
171 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
172 for (uint32_t idx = 0; idx < quant->max()->size(); idx++)
173 chef_quant->add_max(quant->max()->Get(idx));
175 if (quant->scale() != nullptr && quant->scale()->size() > 0)
177 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
178 for (uint32_t idx = 0; idx < quant->scale()->size(); ++idx)
179 chef_quant->add_scale(quant->scale()->Get(idx));
181 if (quant->zero_point() != nullptr && quant->zero_point()->size() > 0)
183 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
184 for (uint32_t idx = 0; idx < quant->zero_point()->size(); ++idx)
185 chef_quant->add_zero_point(quant->zero_point()->Get(idx));
187 tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
188 chef_quant->set_quantized_dimension(quant->quantized_dimension());
193 for (uint32_t i = 0; i < operators->Length(); ++i)
195 const auto *op = operators->Get(i);
196 tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
198 if (const auto *graph_builder = TFliteOpRegistry::get().lookup(builtincode))
200 auto operation = graph_builder->build(op, &tflite_import, model_recipe.get());
202 // common for all operators: inputs, outputs
203 set_inputs(&tflite_import, operation, op);
204 set_outputs(&tflite_import, operation, op);
208 std::string opcodename = tflite_import.opcode_name(op);
209 throw std::runtime_error{"Not supported: " + opcodename};
213 // network inputs/outputs
214 const std::vector<int32_t> &inputs = tflite_import.inputs();
215 const std::vector<int32_t> &outputs = tflite_import.outputs();
217 for (const auto input : inputs)
219 auto tensor = tensors->Get(input);
220 std::string name = tensor_name(tensor);
222 model_recipe->add_input(name);
224 for (const auto output : outputs)
226 auto tensor = tensors->Get(output);
227 std::string name = tensor_name(tensor);
229 model_recipe->add_output(name);
232 return std::move(model_recipe);
235 bool write_recipe(const std::string &filename, std::unique_ptr<ModelRecipe> &recipe)
237 std::fstream fo(filename, std::ios::binary | std::ios::out);
241 throw std::runtime_error{"file store failed"};
244 // Note: SerializeToString() or SerializeToOstream() writes in binary mode
245 // DebugString() and Utf8DebugString() will print as a human readable text
246 fo << recipe->Utf8DebugString();
253 } // namespace tflchef