Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / circle2circle / src / Circle2Circle.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <foder/FileLoader.h>
18
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
25
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
29
30 #include <functional>
31 #include <iostream>
32 #include <string>
33
34 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
35 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
36
37 void print_version(void)
38 {
39   std::cout << "circle2circle version " << vconone::get_string() << std::endl;
40   std::cout << vconone::get_copyright() << std::endl;
41 }
42
43 int entry(int argc, char **argv)
44 {
45   // Simple argument parser (based on map)
46   luci::CircleOptimizer optimizer;
47
48   auto options = optimizer.options();
49   auto settings = luci::UserSettings::settings();
50
51   arser::Arser arser("circle2circle provides circle model optimization and transformations");
52
53   arser.add_argument("--version")
54       .nargs(0)
55       .required(false)
56       .default_value(false)
57       .help("Show version information and exit")
58       .exit_with(print_version);
59
60   arser.add_argument("--all").nargs(0).required(false).default_value(false).help(
61       "Enable all optimize options");
62
63   arser.add_argument("--fold_dequantize")
64       .nargs(0)
65       .required(false)
66       .default_value(false)
67       .help("This will fold dequantize op");
68
69   arser.add_argument("--fuse_activation_function")
70       .nargs(0)
71       .required(false)
72       .default_value(false)
73       .help("This will fuse Activation function to a preceding operator");
74
75   arser.add_argument("--fuse_add_with_tconv")
76       .nargs(0)
77       .required(false)
78       .default_value(false)
79       .help("This will fuse Add operator to Transposed Convolution operator");
80
81   arser.add_argument("--fuse_batchnorm_with_tconv")
82       .nargs(0)
83       .required(false)
84       .default_value(false)
85       .help("This will fuse BatchNorm operators to Transposed Convolution operator");
86
87   arser.add_argument("--fuse_bcq")
88       .nargs(0)
89       .required(false)
90       .default_value(false)
91       .help("This will fuse operators and apply Binary Coded Quantization");
92
93   arser.add_argument("--fuse_instnorm")
94       .nargs(0)
95       .required(false)
96       .default_value(false)
97       .help("This will fuse operators to InstanceNorm operator");
98
99   arser.add_argument("--make_batchnorm_gamma_positive")
100       .nargs(0)
101       .required(false)
102       .default_value(false)
103       .help("This will make negative gamma of BatchNorm into a small positive value (1e-10). Note "
104             "that this pass can change the execution result of the model. So, use it only when the "
105             "impact is known to be acceptable.");
106
107   arser.add_argument("--fuse_preactivation_batchnorm")
108       .nargs(0)
109       .required(false)
110       .default_value(false)
111       .help("This will fuse BatchNorm operators of pre-activations to Convolution operator");
112
113   arser.add_argument("--remove_redundant_transpose")
114       .nargs(0)
115       .required(false)
116       .default_value(false)
117       .help("This will fuse or remove subsequent Transpose operators");
118
119   arser.add_argument("--replace_cw_mul_add_with_depthwise_conv")
120       .nargs(0)
121       .required(false)
122       .default_value(false)
123       .help("This will replace channel-wise mul/add with DepthwiseConv2D operator");
124
125   arser.add_argument("--resolve_customop_add")
126       .nargs(0)
127       .required(false)
128       .default_value(false)
129       .help("This will convert Custom(Add) to Add operator");
130
131   arser.add_argument("--resolve_customop_batchmatmul")
132       .nargs(0)
133       .required(false)
134       .default_value(false)
135       .help("This will convert Custom(BatchMatmul) to BatchMatmul operator");
136
137   arser.add_argument("--resolve_customop_matmul")
138       .nargs(0)
139       .required(false)
140       .default_value(false)
141       .help("This will convert Custom(Matmul) to Matmul operator");
142
143   arser.add_argument("--shuffle_weight_to_16x1float32")
144       .nargs(0)
145       .required(false)
146       .default_value(false)
147       .help("This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32. Note that "
148             "it only converts weights whose row is a multiple of 16");
149
150   arser.add_argument("--substitute_pack_to_reshape")
151       .nargs(0)
152       .required(false)
153       .default_value(false)
154       .help("This will convert single input Pack to Reshape");
155
156   arser.add_argument("--mute_warnings")
157       .nargs(0)
158       .required(false)
159       .default_value(false)
160       .help("This will turn off warning messages");
161
162   arser.add_argument("--disable_validation")
163       .nargs(0)
164       .required(false)
165       .default_value(false)
166       .help("This will turn off operator validations. May help input model investigation.");
167
168   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
169   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
170
171   // sparsification argument
172   arser.add_argument("--sparsify_tensor")
173       .nargs(1)
174       .type(arser::DataType::STR)
175       .required(false)
176       .help("Tensor name that you want to sparsify");
177
178   arser.add_argument("--sparsify_traversal_order")
179       .nargs(1)
180       .type(arser::DataType::STR)
181       .required(false)
182       .default_value("0,1,2,3")
183       .help("Traversal order of dimensions. Default value: 0,1,2,3");
184
185   arser.add_argument("--sparsify_format")
186       .nargs(1)
187       .type(arser::DataType::STR)
188       .required(false)
189       .default_value("d,s")
190       .help("Format of each dimension. 'd' stands for dense, 's' stands for sparse(CSR). Default "
191             "value: d,s");
192
193   arser.add_argument("--sparsify_block_size")
194       .nargs(1)
195       .type(arser::DataType::STR)
196       .required(false)
197       .help("Size of each block dimension");
198
199   arser.add_argument("--sparsify_block_map")
200       .nargs(1)
201       .type(arser::DataType::STR)
202       .required(false)
203       .default_value("0,1")
204       .help("Map from block dimension to the original tensor dimension. Default value: 0,1");
205
206   try
207   {
208     arser.parse(argc, argv);
209   }
210   catch (const std::runtime_error &err)
211   {
212     std::cout << err.what() << std::endl;
213     std::cout << arser;
214     return 255;
215   }
216
217   if (arser.get<bool>("--all"))
218   {
219     options->enable(Algorithms::FuseBCQ);
220     options->enable(Algorithms::FuseInstanceNorm);
221     options->enable(Algorithms::ResolveCustomOpAdd);
222     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
223     options->enable(Algorithms::ResolveCustomOpMatMul);
224     options->enable(Algorithms::RemoveRedundantTranspose);
225     options->enable(Algorithms::SubstitutePackToReshape);
226   }
227   if (arser.get<bool>("--fold_dequantize"))
228     options->enable(Algorithms::FoldDequantize);
229   if (arser.get<bool>("--fuse_activation_function"))
230     options->enable(Algorithms::FuseActivationFunction);
231   if (arser.get<bool>("--fuse_add_with_tconv"))
232     options->enable(Algorithms::FuseAddWithTConv);
233   if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
234     options->enable(Algorithms::FuseBatchNormWithTConv);
235   if (arser.get<bool>("--fuse_bcq"))
236     options->enable(Algorithms::FuseBCQ);
237   if (arser.get<bool>("--fuse_instnorm"))
238     options->enable(Algorithms::FuseInstanceNorm);
239   if (arser.get<bool>("--make_batchnorm_gamma_positive"))
240     options->enable(Algorithms::MakeBatchNormGammaPositive);
241   if (arser.get<bool>("--fuse_preactivation_batchnorm"))
242     options->enable(Algorithms::FusePreActivationBatchNorm);
243   if (arser.get<bool>("--remove_redundant_transpose"))
244     options->enable(Algorithms::RemoveRedundantTranspose);
245   if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
246     options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
247   if (arser.get<bool>("--resolve_customop_add"))
248     options->enable(Algorithms::ResolveCustomOpAdd);
249   if (arser.get<bool>("--resolve_customop_batchmatmul"))
250     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
251   if (arser.get<bool>("--resolve_customop_matmul"))
252     options->enable(Algorithms::ResolveCustomOpMatMul);
253   if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
254     options->enable(Algorithms::ShuffleWeightTo16x1Float32);
255   if (arser.get<bool>("--substitute_pack_to_reshape"))
256     options->enable(Algorithms::SubstitutePackToReshape);
257
258   if (arser.get<bool>("--mute_warnings"))
259     settings->set(luci::UserSettings::Key::MuteWarnings, true);
260   if (arser.get<bool>("--disable_validation"))
261     settings->set(luci::UserSettings::Key::DisableValidation, true);
262
263   std::string input_path = arser.get<std::string>("input");
264   std::string output_path = arser.get<std::string>("output");
265
266   if (arser["--sparsify_tensor"])
267   {
268     options->enable(Algorithms::SparsifyTensorPass);
269     options->param(AlgorithmParameters::Sparsify_tensor_name,
270                    arser.get<std::string>("--sparsify_tensor"));
271     options->param(AlgorithmParameters::Sparsify_traversal_order,
272                    arser.get<std::string>("--sparsify_traversal_order"));
273     options->param(AlgorithmParameters::Sparsify_format,
274                    arser.get<std::string>("--sparsify_format"));
275     if (arser["--sparsify_block_size"])
276       options->param(AlgorithmParameters::Sparsify_block_size,
277                      arser.get<std::string>("--sparsify_block_size"));
278     else
279     {
280       std::cerr << "ERROR: Block size not provided" << std::endl;
281       return 255;
282     }
283     options->param(AlgorithmParameters::Sparsify_block_map,
284                    arser.get<std::string>("--sparsify_block_map"));
285   }
286
287   // Load model from the file
288   foder::FileLoader file_loader{input_path};
289   std::vector<char> model_data;
290
291   try
292   {
293     model_data = file_loader.load();
294   }
295   catch (const std::runtime_error &err)
296   {
297     std::cerr << err.what() << std::endl;
298     return EXIT_FAILURE;
299   }
300
301   flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
302   if (!circle::VerifyModelBuffer(verifier))
303   {
304     std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
305     return EXIT_FAILURE;
306   }
307
308   const circle::Model *circle_model = circle::GetModel(model_data.data());
309   if (circle_model == nullptr)
310   {
311     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
312     return EXIT_FAILURE;
313   }
314
315   // Import from input Circle file
316   luci::Importer importer;
317   auto module = importer.importModule(circle_model);
318
319   // call luci optimizations for module
320   optimizer.optimize(module.get());
321
322   for (size_t idx = 0; idx < module->size(); ++idx)
323   {
324     auto graph = module->graph(idx);
325
326     // call luci optimizations for graph
327     optimizer.optimize(graph);
328     optimizer.sparsify(graph);
329
330     if (!luci::validate(graph))
331     {
332       if (settings->get(luci::UserSettings::Key::DisableValidation))
333         std::cerr << "WARNING: Optimized graph is invalid" << std::endl;
334       else
335       {
336         std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
337         return 255;
338       }
339     }
340   }
341
342   // Export to output Circle file
343   luci::CircleExporter exporter;
344
345   luci::CircleFileExpContract contract(module.get(), output_path);
346
347   if (!exporter.invoke(&contract))
348   {
349     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
350     return 255;
351   }
352
353   return 0;
354 }