2203f5906b06f1b268694c2d6eb5ff5cde67e6c6
[platform/core/ml/nnfw.git] / compiler / tflchef / tflite / src / RecipeChef.cpp
1 /*
2  * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <tflchef/RecipeChef.h>
18 #include <mio_tflite2121/Helper.h>
19
20 #include "Convert.h"
21 #include "TFliteImport.h"
22 #include "TFliteOpChef.h"
23 #include "TFliteOpChefs.h"
24 #include "TFliteOpRegistry.h"
25
26 #include <fstream>
27 #include <sstream>
28
29 namespace tflchef
30 {
31
32 void set_inputs(TFliteImport *import, tflchef::Operation *operation, const tflite::Operator *op)
33 {
34   auto tensors = import->tensors();
35   const std::vector<int32_t> &inputs = as_index_vector(op->inputs());
36
37   for (auto input : inputs)
38   {
39     if (input == -1)
40     {
41       operation->add_input("");
42     }
43     else
44     {
45       auto tensor = tensors->Get(input);
46       std::string name = mio::tflite::tensor_name(tensor);
47       operation->add_input(name);
48     }
49   }
50 }
51
52 void set_outputs(TFliteImport *import, tflchef::Operation *operation, const tflite::Operator *op)
53 {
54   auto tensors = import->tensors();
55   const std::vector<int32_t> &outputs = as_index_vector(op->outputs());
56
57   for (auto output : outputs)
58   {
59     auto tensor = tensors->Get(output);
60     std::string name = mio::tflite::tensor_name(tensor);
61     operation->add_output(name);
62   }
63 }
64
65 /**
66  * @brief This will build ModelRecipe from tflite::Model
67  *        First to check operand filler options by scanning all operators,
68  *        then translate all operands and operators.
69  *        Last will set network inputs and outputs.
70  */
71 std::unique_ptr<ModelRecipe> generate_recipe(const tflite::Model *model)
72 {
73   std::unique_ptr<ModelRecipe> model_recipe{new ModelRecipe()};
74
75   TFliteImport tflite_import(model);
76
77   assert(tflite_import.num_subgraph() == 1);
78   tflite_import.select_sub_graph(0);
79
80   auto tensors = tflite_import.tensors();
81   auto buffers = tflite_import.buffers();
82   auto operators = tflite_import.operators();
83
84   // operand fillers for adding all operators
85   for (uint32_t i = 0; i < operators->Length(); ++i)
86   {
87     const auto *op = operators->Get(i);
88     tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
89
90     if (const auto *graph_builder = TFliteOpRegistry::get().lookup(builtincode))
91     {
92       graph_builder->filler(op, &tflite_import, model_recipe.get());
93     }
94     else
95     {
96       std::string opcodename = tflite_import.opcode_name(op);
97       throw std::runtime_error{"Not supported: " + opcodename};
98     }
99   }
100
101   // add all operands(tensors)
102   for (uint32_t i = 0; i < tensors->Length(); ++i)
103   {
104     auto tensor = tensors->Get(i);
105
106     // check buffer
107     if (tensor->buffer() >= buffers->size())
108       throw std::runtime_error{"file load failed"};
109
110     ::tflchef::Operand *operand = model_recipe->add_operand();
111
112     operand->set_name(mio::tflite::tensor_name(tensor));
113     operand->set_type(as_tflchef_type(tensor->type()));
114     operand->set_is_variable(tensor->is_variable());
115
116     if (tensor->shape())
117     {
118       std::vector<int32_t> dims = as_index_vector(tensor->shape());
119       ::tflchef::TensorShape *shape = operand->mutable_shape();
120       for (auto dim : dims)
121       {
122         shape->add_dim(dim);
123       }
124     }
125
126     // filler for weights, bias and so on
127     std::vector<int32_t> expvalues;
128     std::vector<float> expfvalues;
129     if (tflite_import.get_tensor_filler(i))
130     {
131       tflchef::TensorFiller *filler = operand->mutable_filler();
132       // Note: it is OK to use random weights for functionality validation
133       filler->set_tag("gaussian");
134       filler->add_arg("0.0"); // average
135       filler->add_arg("0.1"); // standard deviation
136     }
137     else if (tflite_import.get_tensor_filler(i, expvalues))
138     {
139       tflchef::TensorFiller *filler = operand->mutable_filler();
140       filler->set_tag("explicit");
141       for (auto value : expvalues)
142       {
143         std::ostringstream ss;
144         ss << value;
145         filler->add_arg(ss.str());
146       }
147     }
148     else if (tflite_import.get_tensor_filler(i, expfvalues))
149     {
150       tflchef::TensorFiller *filler = operand->mutable_filler();
151       filler->set_tag("explicit");
152       for (auto value : expfvalues)
153       {
154         std::ostringstream ss;
155         ss << value;
156         filler->add_arg(ss.str());
157       }
158     }
159
160     auto quant = tensor->quantization();
161     if (quant != nullptr)
162     {
163       // Note: Calling 'operand->mutable_quant()' will create empty 'quant' node
164       // in the recipe file. We want this only when valid parameter exist.
165       if (quant->min() != nullptr && quant->min()->size() > 0)
166       {
167         tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
168         for (uint32_t idx = 0; idx < quant->min()->size(); ++idx)
169           chef_quant->add_min(quant->min()->Get(idx));
170       }
171       if (quant->max() != nullptr && quant->max()->size() > 0)
172       {
173         tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
174         for (uint32_t idx = 0; idx < quant->max()->size(); idx++)
175           chef_quant->add_max(quant->max()->Get(idx));
176       }
177       if (quant->scale() != nullptr && quant->scale()->size() > 0)
178       {
179         tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
180         for (uint32_t idx = 0; idx < quant->scale()->size(); ++idx)
181           chef_quant->add_scale(quant->scale()->Get(idx));
182       }
183       if (quant->zero_point() != nullptr && quant->zero_point()->size() > 0)
184       {
185         tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
186         for (uint32_t idx = 0; idx < quant->zero_point()->size(); ++idx)
187           chef_quant->add_zero_point(quant->zero_point()->Get(idx));
188       }
189       tflchef::TensorQuantization *chef_quant = operand->mutable_quant();
190       chef_quant->set_quantized_dimension(quant->quantized_dimension());
191     }
192
193     auto sparsity = tensor->sparsity();
194     if (sparsity != nullptr)
195     {
196       tflchef::TensorSparsity *chef_sparsity = operand->mutable_sparsity();
197       // traversal_order
198       auto chef_traversal_order = chef_sparsity->mutable_traversal_order();
199       for (const auto &to : *(sparsity->traversal_order()))
200       {
201         chef_traversal_order->add_dim(to);
202       }
203       // block_map
204       auto chef_block_map = chef_sparsity->mutable_block_map();
205       for (const auto &bm : *(sparsity->block_map()))
206       {
207         chef_block_map->add_dim(bm);
208       }
209       // dim_metadata
210       for (const auto &dm : *(sparsity->dim_metadata()))
211       {
212         auto chef_dm = chef_sparsity->add_dim_metadata();
213         // format
214         chef_dm->set_format(as_tflchef_sparse_dim_type(dm->format()));
215         // dense_size
216         chef_dm->set_dense_size(dm->dense_size());
217         // array_segments
218         auto chef_array_segments = chef_dm->mutable_array_segments();
219         switch (dm->array_segments_type())
220         {
221           case tflite::SparseIndexVector_NONE:
222             // DO NOTHING
223             break;
224           case tflite::SparseIndexVector_Int32Vector:
225             for (const auto &as : *(dm->array_segments_as_Int32Vector()->values()))
226             {
227               chef_array_segments->add_dim(as);
228             }
229             break;
230           case tflite::SparseIndexVector_Uint16Vector:
231             for (const auto &as : *(dm->array_segments_as_Uint16Vector()->values()))
232             {
233               chef_array_segments->add_dim(as);
234             }
235             break;
236           case tflite::SparseIndexVector_Uint8Vector:
237             for (const auto &as : *(dm->array_segments_as_Uint8Vector()->values()))
238             {
239               chef_array_segments->add_dim(as);
240             }
241             break;
242           default:
243             throw std::runtime_error("unsupported sparse index vector type");
244         }
245         // array_indices
246         auto chef_array_indices = chef_dm->mutable_array_indices();
247         switch (dm->array_indices_type())
248         {
249           case tflite::SparseIndexVector_NONE:
250             // DO NOTHING
251             break;
252           case tflite::SparseIndexVector_Int32Vector:
253             for (const auto &as : *(dm->array_indices_as_Int32Vector()->values()))
254             {
255               chef_array_indices->add_dim(as);
256             }
257             break;
258           case tflite::SparseIndexVector_Uint16Vector:
259             for (const auto &as : *(dm->array_indices_as_Uint16Vector()->values()))
260             {
261               chef_array_indices->add_dim(as);
262             }
263             break;
264           case tflite::SparseIndexVector_Uint8Vector:
265             for (const auto &as : *(dm->array_indices_as_Uint8Vector()->values()))
266             {
267               chef_array_indices->add_dim(as);
268             }
269             break;
270           default:
271             throw std::runtime_error("unsupported sparse index vector type");
272         }
273       }
274     }
275
276     auto shape_signature = tensor->shape_signature();
277     if (shape_signature != nullptr)
278     {
279       tflchef::ShapeSignature *chef_shape_signature = operand->mutable_shape_signature();
280       for (uint32_t i = 0; i < shape_signature->size(); ++i)
281       {
282         chef_shape_signature->add_dim(shape_signature->Get(i));
283       }
284     }
285   }
286
287   // add all operators
288   for (uint32_t i = 0; i < operators->Length(); ++i)
289   {
290     const auto *op = operators->Get(i);
291     tflite::BuiltinOperator builtincode = tflite_import.builtin_code(op);
292
293     if (const auto *graph_builder = TFliteOpRegistry::get().lookup(builtincode))
294     {
295       auto operation = graph_builder->build(op, &tflite_import, model_recipe.get());
296
297       // common for all operators: inputs, outputs
298       set_inputs(&tflite_import, operation, op);
299       set_outputs(&tflite_import, operation, op);
300     }
301     else
302     {
303       std::string opcodename = tflite_import.opcode_name(op);
304       throw std::runtime_error{"Not supported: " + opcodename};
305     }
306   }
307
308   // network inputs/outputs
309   const std::vector<int32_t> &inputs = tflite_import.inputs();
310   const std::vector<int32_t> &outputs = tflite_import.outputs();
311
312   for (const auto input : inputs)
313   {
314     auto tensor = tensors->Get(input);
315     std::string name = mio::tflite::tensor_name(tensor);
316
317     model_recipe->add_input(name);
318   }
319   for (const auto output : outputs)
320   {
321     auto tensor = tensors->Get(output);
322     std::string name = mio::tflite::tensor_name(tensor);
323
324     model_recipe->add_output(name);
325   }
326
327   return std::move(model_recipe);
328 }
329
330 bool write_recipe(const std::string &filename, std::unique_ptr<ModelRecipe> &recipe)
331 {
332   std::fstream fo(filename, std::ios::binary | std::ios::out);
333
334   if (!fo.is_open())
335   {
336     throw std::runtime_error{"file store failed"};
337   }
338
339   // Note: SerializeToString() or SerializeToOstream() writes in binary mode
340   // DebugString() and Utf8DebugString() will print as a human readable text
341   fo << recipe->Utf8DebugString();
342
343   fo.close();
344
345   return true;
346 }
347
348 } // namespace tflchef