/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <circlechef/RecipeChef.h>

#include "CircleImport.h"
#include "CircleOpChef.h"
#include "CircleOpChefs.h"
#include "CircleOpRegistry.h"

#include <cassert>
#include <fstream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
31 void set_inputs(CircleImport *import, circlechef::Operation *operation, const circle::Operator *op)
33 auto tensors = import->tensors();
34 const std::vector<int32_t> &inputs = as_index_vector(op->inputs());
36 for (auto input : inputs)
40 operation->add_input("");
44 auto tensor = tensors->Get(input);
45 std::string name = tensor_name(tensor);
46 operation->add_input(name);
51 void set_outputs(CircleImport *import, circlechef::Operation *operation, const circle::Operator *op)
53 auto tensors = import->tensors();
54 const std::vector<int32_t> &outputs = as_index_vector(op->outputs());
56 for (auto output : outputs)
58 auto tensor = tensors->Get(output);
59 std::string name = tensor_name(tensor);
60 operation->add_output(name);
65 * @brief This will build ModelRecipe from circle::Model
66 * First to check operand filler options by scanning all operators,
67 * then translate all operands and operators.
68 * Last will set network inputs and outputs.
70 std::unique_ptr<ModelRecipe> generate_recipe(const circle::Model *model)
72 std::unique_ptr<ModelRecipe> model_recipe{new ModelRecipe()};
74 CircleImport circle_import(model);
76 assert(circle_import.num_subgraph() == 1);
77 circle_import.select_sub_graph(0);
79 auto tensors = circle_import.tensors();
80 auto buffers = circle_import.buffers();
81 auto operators = circle_import.operators();
83 // operand fillers for adding all operators
84 for (uint32_t i = 0; i < operators->Length(); ++i)
86 const auto *op = operators->Get(i);
87 circle::BuiltinOperator builtincode = circle_import.builtin_code(op);
89 if (const auto *graph_builder = CircleOpRegistry::get().lookup(builtincode))
91 graph_builder->filler(op, &circle_import, model_recipe.get());
95 std::string opcodename = circle_import.opcode_name(op);
96 throw std::runtime_error{"Not supported: " + opcodename};
100 // add all operands(tensors)
101 for (uint32_t i = 0; i < tensors->Length(); ++i)
103 auto tensor = tensors->Get(i);
106 if (tensor->buffer() >= buffers->size())
107 throw std::runtime_error{"file load failed"};
109 ::circlechef::Operand *operand = model_recipe->add_operand();
111 operand->set_name(tensor_name(tensor));
112 operand->set_type(as_circlechef_type(tensor->type()));
114 std::vector<int32_t> dims = as_index_vector(tensor->shape());
115 ::circlechef::TensorShape *shape = operand->mutable_shape();
116 for (auto dim : dims)
121 // filler for weights, bias and so on
122 std::vector<int32_t> expvalues;
123 std::vector<float> expfvalues;
124 if (circle_import.get_tensor_filler(i))
126 circlechef::TensorFiller *filler = operand->mutable_filler();
127 // Note: it is OK to use random weights for functionality validation
128 filler->set_tag("gaussian");
129 filler->add_arg("0.0"); // average
130 filler->add_arg("0.1"); // standard deviation
132 else if (circle_import.get_tensor_filler(i, expvalues))
134 circlechef::TensorFiller *filler = operand->mutable_filler();
135 filler->set_tag("explicit");
136 for (auto value : expvalues)
138 std::ostringstream ss;
140 filler->add_arg(ss.str());
143 else if (circle_import.get_tensor_filler(i, expfvalues))
145 circlechef::TensorFiller *filler = operand->mutable_filler();
146 filler->set_tag("explicit");
147 for (auto value : expfvalues)
149 std::ostringstream ss;
151 filler->add_arg(ss.str());
155 auto quant = tensor->quantization();
156 if (quant != nullptr)
158 // Note: Calling 'operand->mutable_quant()' will create empty 'quant' node
159 // in the recipe file. We want this only when valid parameter exist.
160 if (quant->min() != nullptr && quant->min()->size() > 0)
162 circlechef::TensorQuantization *chef_quant = operand->mutable_quant();
163 for (uint32_t idx = 0; idx < quant->min()->size(); ++idx)
164 chef_quant->add_min(quant->min()->Get(idx));
166 if (quant->max() != nullptr && quant->max()->size() > 0)
168 circlechef::TensorQuantization *chef_quant = operand->mutable_quant();
169 for (uint32_t idx = 0; idx < quant->max()->size(); idx++)
170 chef_quant->add_max(quant->max()->Get(idx));
172 if (quant->scale() != nullptr && quant->scale()->size() > 0)
174 circlechef::TensorQuantization *chef_quant = operand->mutable_quant();
175 for (uint32_t idx = 0; idx < quant->scale()->size(); ++idx)
176 chef_quant->add_scale(quant->scale()->Get(idx));
178 if (quant->zero_point() != nullptr && quant->zero_point()->size() > 0)
180 circlechef::TensorQuantization *chef_quant = operand->mutable_quant();
181 for (uint32_t idx = 0; idx < quant->zero_point()->size(); ++idx)
182 chef_quant->add_zero_point(quant->zero_point()->Get(idx));
184 circlechef::TensorQuantization *chef_quant = operand->mutable_quant();
185 chef_quant->set_quantized_dimension(quant->quantized_dimension());
190 for (uint32_t i = 0; i < operators->Length(); ++i)
192 const auto *op = operators->Get(i);
193 circle::BuiltinOperator builtincode = circle_import.builtin_code(op);
195 if (const auto *graph_builder = CircleOpRegistry::get().lookup(builtincode))
197 auto operation = graph_builder->build(op, &circle_import, model_recipe.get());
199 // common for all operators: inputs, outputs
200 set_inputs(&circle_import, operation, op);
201 set_outputs(&circle_import, operation, op);
205 std::string opcodename = circle_import.opcode_name(op);
206 throw std::runtime_error{"Not supported: " + opcodename};
210 // network inputs/outputs
211 const std::vector<int32_t> &inputs = circle_import.inputs();
212 const std::vector<int32_t> &outputs = circle_import.outputs();
214 for (const auto input : inputs)
216 auto tensor = tensors->Get(input);
217 std::string name = tensor_name(tensor);
219 model_recipe->add_input(name);
221 for (const auto output : outputs)
223 auto tensor = tensors->Get(output);
224 std::string name = tensor_name(tensor);
226 model_recipe->add_output(name);
229 return std::move(model_recipe);
232 bool write_recipe(const std::string &filename, std::unique_ptr<ModelRecipe> &recipe)
234 std::fstream fo(filename, std::ios::binary | std::ios::out);
238 throw std::runtime_error{"file store failed"};
241 // Note: SerializeToString() or SerializeToOstream() writes in binary mode
242 // DebugString() and Utf8DebugString() will print as a human readable text
243 fo << recipe->Utf8DebugString();
250 } // namespace circlechef