Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle2circle / src / Circle2Circle.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <luci/ImporterEx.h>
18 #include <luci/CircleOptimizer.h>
19 #include <luci/DynamicBatchToSingleBatch.h>
20 #include <luci/Service/ChangeOutputs.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
25
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
29
30 #include <functional>
31 #include <iostream>
32 #include <sstream>
33 #include <string>
34 #include <vector>
35 #include <cstdlib>
36
37 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
38 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
39
40 void print_version(void)
41 {
42   std::cout << "circle2circle version " << vconone::get_string() << std::endl;
43   std::cout << vconone::get_copyright() << std::endl;
44 }
45
46 void csv_tokenize(const std::string &data, std::vector<std::string> &result)
47 {
48   const char delim = ',';
49   std::string token;
50   std::stringstream ss(data);
51
52   while (std::getline(ss, token, delim))
53     result.push_back(token);
54 }
55
56 void add_switch(arser::Arser &arser, const char *opt, const char *desc)
57 {
58   arser.add_argument(opt).nargs(0).default_value(false).help(desc);
59 }
60
61 int entry(int argc, char **argv)
62 {
63   // Simple argument parser (based on map)
64   luci::CircleOptimizer optimizer;
65
66   auto options = optimizer.options();
67   auto settings = luci::UserSettings::settings();
68
69   arser::Arser arser("circle2circle provides circle model optimization and transformations");
70
71   arser::Helper::add_version(arser, print_version);
72   arser::Helper::add_verbose(arser);
73
74   add_switch(arser, "--fold_add_v2", "This will fold AddV2 operators with constant inputs");
75   add_switch(arser, "--fold_cast", "This will fold Cast operators with constant input");
76   add_switch(arser, "--fold_densify",
77              "This will fold Densify operators with sparse constant input");
78   add_switch(arser, "--fold_dequantize", "This will fold dequantize op");
79   add_switch(arser, "--fold_dwconv",
80              "This will fold Depthwise Convolution operator with constant inputs");
81   add_switch(arser, "--fold_fully_connected",
82              "This will fold FullyConnected operator with constant inputs");
83   add_switch(arser, "--fold_gather", "This will fold Gather operator");
84   add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator");
85   add_switch(arser, "--forward_reshape_to_unaryop",
86              "This will move Reshape after UnaryOp for centain condition");
87   add_switch(arser, "--forward_transpose_op",
88              "This will move Transpose Op forward if possible (for further optimization)");
89   add_switch(arser, "--fuse_activation_function",
90              "This will fuse Activation function to a preceding operator");
91   add_switch(arser, "--fuse_add_with_fully_connected",
92              "This will fuse Add operator to FullyConnected operator");
93   add_switch(arser, "--fuse_add_with_tconv",
94              "This will fuse Add operator to Transposed Convolution operator");
95   add_switch(arser, "--fuse_batchnorm_with_conv",
96              "This will fuse BatchNorm operators to Convolution operator");
97   add_switch(arser, "--fuse_batchnorm_with_dwconv",
98              "This will fuse BatchNorm operators to Depthwise Convolution operator");
99   add_switch(arser, "--fuse_batchnorm_with_tconv",
100              "This will fuse BatchNorm operators to Transposed Convolution operator");
101   add_switch(arser, "--fuse_bcq", "This will fuse operators and apply Binary Coded Quantization");
102   add_switch(arser, "--fuse_instnorm", "This will fuse operators to InstanceNorm operator");
103   add_switch(arser, "--fuse_mean_with_mean",
104              "This will fuse two Mean operations when they follow one by one. This will fold them "
105              "into one operation and merge reduction indices.");
106   add_switch(arser, "--fuse_transpose_with_mean",
107              "This will fuse Mean operation with a preceding Transpose under certain conditions.");
108   add_switch(arser, "--make_batchnorm_gamma_positive",
109              "This will make negative gamma of BatchNorm into a small positive value (1e-10). "
110              "Note that this pass can change the execution result of the model. So, use it only "
111              "when the impact is known to be acceptable.");
112   add_switch(arser, "--fuse_preactivation_batchnorm",
113              "This will fuse BatchNorm operators of pre-activations to Convolution operator");
114   add_switch(arser, "--fuse_prelu", "This will fuse operators to PReLU operator");
115   add_switch(arser, "--fuse_gelu", "This will fuse operators to GeLU operator");
116   add_switch(arser, "--remove_duplicate_const", "This will remove all duplicate constant nodes");
117   add_switch(arser, "--remove_fakequant", "This will remove FakeQuant operators");
118   add_switch(arser, "--remove_quantdequant", "This will remove Quantize-Dequantize sequence");
119   add_switch(arser, "--remove_redundant_quantize", "This will remove redundant Quantize operators");
120   add_switch(arser, "--remove_redundant_reshape",
121              "This will fuse or remove subsequent Reshape operators");
122   add_switch(arser, "--remove_redundant_transpose",
123              "This will fuse or remove subsequent Transpose operators");
124   add_switch(arser, "--remove_unnecessary_reshape",
125              "This will remove unnecessary reshape operators");
126   add_switch(arser, "--remove_unnecessary_slice", "This will remove unnecessary slice operators");
127   add_switch(arser, "--remove_unnecessary_strided_slice",
128              "This will remove unnecessary strided slice operators");
129   add_switch(arser, "--remove_unnecessary_split", "This will remove unnecessary split operators");
130   add_switch(arser, "--replace_cw_mul_add_with_depthwise_conv",
131              "This will replace channel-wise mul/add with DepthwiseConv2D operator");
132   add_switch(arser, "--replace_sub_with_add", "This will replace sub with add operator");
133   add_switch(arser, "--resolve_customop_add", "This will convert Custom(Add) to Add operator");
134   add_switch(arser, "--resolve_customop_batchmatmul",
135              "This will convert Custom(BatchMatmul) to BatchMatmul operator");
136   add_switch(arser, "--resolve_customop_matmul",
137              "This will convert Custom(Matmul) to Matmul operator");
138   add_switch(arser, "--resolve_customop_max_pool_with_argmax",
139              "This will convert Custom(MaxPoolWithArgmax) to equivalent set of operators");
140   add_switch(arser, "--resolve_customop_splitv",
141              "This will convert Custom(SplitV) to SplitV operator");
142   add_switch(arser, "--shuffle_weight_to_16x1float32",
143              "This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32. Note that "
144              "it only converts weights whose row is a multiple of 16");
145   add_switch(arser, "--replace_non_const_fc_with_batch_matmul",
146              "Replace FullyConnected with BatchMatMul when its weight is non-constant");
147   add_switch(arser, "--substitute_pack_to_reshape",
148              "This will convert single input Pack to Reshape");
149   add_switch(arser, "--substitute_padv2_to_pad",
150              "This will convert certain condition PadV2 to Pad");
151   add_switch(arser, "--substitute_splitv_to_split",
152              "This will convert certain condition SplitV to Split operator");
153   add_switch(arser, "--substitute_squeeze_to_reshape",
154              "This will convert certain condition Squeeze to Reshape");
155   add_switch(arser, "--substitute_strided_slice_to_reshape",
156              "This will convert certain condition Strided_Slice to Reshape");
157   add_switch(arser, "--substitute_transpose_to_reshape",
158              "This will convert single input Transpose to Reshape");
159   add_switch(arser, "--expand_broadcast_const", "This will expand broadcastable constant inputs");
160   add_switch(arser, "--unroll_unidirseqlstm", "Unroll UnidirectionalSequenceLSTM operator.");
161   add_switch(arser, "--convert_nchw_to_nhwc",
162              "Experimental: This will convert NCHW operators to NHWC under the assumption that "
163              "input model is NCHW.");
164   add_switch(arser, "--nchw_to_nhwc_input_shape",
165              "Convert the input shape of the model (argument for --convert_nchw_to_nhwc).");
166   add_switch(arser, "--nchw_to_nhwc_output_shape",
167              "Convert the output shape of the model (argument for --convert_nchw_to_nhwc).");
168   add_switch(arser, "--transform_min_max_to_relu6",
169              "Transform Minimum(6)-Maximum(0) pattern to Relu6 operator");
170   add_switch(arser, "--transform_min_relu_to_relu6",
171              "Transform Minimum(6)-Relu pattern to Relu6 operator");
172   add_switch(arser, "--decompose_hardswish",
173              "Decompose HardSwish operator to Add, Mul and Relu6 operators");
174   add_switch(arser, "--mute_warnings", "This will turn off warning messages");
175   add_switch(arser, "--disable_validation",
176              "This will turn off operator validations. May help input model investigation.");
177   add_switch(arser, "--generate_profile_data", "This will turn on profiling data generation.");
178
179   // Convert dynamic batch to single batch
180   // Users have to use this option only when the first dimension of rank 4 input (NHWC or NCHW)
181   // is dynamic. Remove this comment after non-rank 4 is supported.
182   add_switch(arser, "--dynamic_batch_to_single_batch",
183              "Convert dynamic batch size (first dimension) of inputs to 1.");
184
185   arser.add_argument("--change_outputs")
186     .help("Experimental: Change first subgraph output nodes to CSV names");
187
188   arser.add_argument("input").help("Input circle model");
189   arser.add_argument("output").help("Output circle model");
190
191   // sparsification argument
192   arser.add_argument("--sparsify_tensor").help("Tensor name that you want to sparsify");
193
194   arser.add_argument("--sparsify_traversal_order")
195     .default_value("0,1,2,3")
196     .help("Traversal order of dimensions. Default value: 0,1,2,3");
197
198   arser.add_argument("--sparsify_format")
199     .default_value("d,s")
200     .help("Format of each dimension. 'd' stands for dense, 's' stands for sparse(CSR). Default "
201           "value: d,s");
202
203   arser.add_argument("--sparsify_block_size").help("Size of each block dimension");
204
205   arser.add_argument("--sparsify_block_map")
206     .default_value("0,1")
207     .help("Map from block dimension to the original tensor dimension. Default value: 0,1");
208
209   try
210   {
211     arser.parse(argc, argv);
212   }
213   catch (const std::runtime_error &err)
214   {
215     std::cerr << err.what() << std::endl;
216     std::cout << arser;
217     return 255;
218   }
219
220   if (arser.get<bool>("--verbose"))
221   {
222     // The third parameter of setenv means REPLACE.
223     // If REPLACE is zero, it does not overwrite an existing value.
224     setenv("LUCI_LOG", "100", 0);
225   }
226   if (arser.get<bool>("--fold_add_v2"))
227     options->enable(Algorithms::FoldAddV2);
228   if (arser.get<bool>("--fold_cast"))
229     options->enable(Algorithms::FoldCast);
230   if (arser.get<bool>("--fold_densify"))
231     options->enable(Algorithms::FoldDensify);
232   if (arser.get<bool>("--fold_dequantize"))
233     options->enable(Algorithms::FoldDequantize);
234   if (arser.get<bool>("--fold_dwconv"))
235     options->enable(Algorithms::FoldDepthwiseConv2D);
236   if (arser.get<bool>("--fold_fully_connected"))
237     options->enable(Algorithms::FoldFullyConnected);
238   if (arser.get<bool>("--fold_gather"))
239     options->enable(Algorithms::FoldGather);
240   if (arser.get<bool>("--fold_sparse_to_dense"))
241     options->enable(Algorithms::FoldSparseToDense);
242   if (arser.get<bool>("--forward_reshape_to_unaryop"))
243     options->enable(Algorithms::ForwardReshapeToUnaryOp);
244   if (arser.get<bool>("--forward_transpose_op"))
245     options->enable(Algorithms::ForwardTransposeOp);
246   if (arser.get<bool>("--fuse_activation_function"))
247     options->enable(Algorithms::FuseActivationFunction);
248   if (arser.get<bool>("--fuse_batchnorm_with_conv"))
249     options->enable(Algorithms::FuseBatchNormWithConv);
250   if (arser.get<bool>("--fuse_add_with_fully_connected"))
251     options->enable(Algorithms::FuseAddWithFullyConnected);
252   if (arser.get<bool>("--fuse_add_with_tconv"))
253     options->enable(Algorithms::FuseAddWithTConv);
254   if (arser.get<bool>("--fuse_batchnorm_with_dwconv"))
255     options->enable(Algorithms::FuseBatchNormWithDwConv);
256   if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
257     options->enable(Algorithms::FuseBatchNormWithTConv);
258   if (arser.get<bool>("--fuse_bcq"))
259     options->enable(Algorithms::FuseBCQ);
260   if (arser.get<bool>("--fuse_instnorm"))
261     options->enable(Algorithms::FuseInstanceNorm);
262   if (arser.get<bool>("--fuse_mean_with_mean"))
263     options->enable(Algorithms::FuseMeanWithMean);
264   if (arser.get<bool>("--make_batchnorm_gamma_positive"))
265     options->enable(Algorithms::MakeBatchNormGammaPositive);
266   if (arser.get<bool>("--fuse_preactivation_batchnorm"))
267     options->enable(Algorithms::FusePreActivationBatchNorm);
268   if (arser.get<bool>("--fuse_prelu"))
269     options->enable(Algorithms::FusePRelu);
270   if (arser.get<bool>("--fuse_gelu"))
271     options->enable(Algorithms::FuseGelu);
272   if (arser.get<bool>("--fuse_transpose_with_mean"))
273     options->enable(Algorithms::FuseTransposeWithMean);
274   if (arser.get<bool>("--remove_duplicate_const"))
275     options->enable(Algorithms::RemoveDuplicateConst);
276   if (arser.get<bool>("--remove_fakequant"))
277     options->enable(Algorithms::RemoveFakeQuant);
278   if (arser.get<bool>("--remove_quantdequant"))
279     options->enable(Algorithms::RemoveQuantDequantSeq);
280   if (arser.get<bool>("--remove_redundant_quantize"))
281     options->enable(Algorithms::RemoveRedundantQuantize);
282   if (arser.get<bool>("--remove_redundant_reshape"))
283     options->enable(Algorithms::RemoveRedundantReshape);
284   if (arser.get<bool>("--remove_redundant_transpose"))
285     options->enable(Algorithms::RemoveRedundantTranspose);
286   if (arser.get<bool>("--remove_unnecessary_reshape"))
287     options->enable(Algorithms::RemoveUnnecessaryReshape);
288   if (arser.get<bool>("--remove_unnecessary_slice"))
289     options->enable(Algorithms::RemoveUnnecessarySlice);
290   if (arser.get<bool>("--remove_unnecessary_strided_slice"))
291     options->enable(Algorithms::RemoveUnnecessaryStridedSlice);
292   if (arser.get<bool>("--remove_unnecessary_split"))
293     options->enable(Algorithms::RemoveUnnecessarySplit);
294   if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
295     options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
296   if (arser.get<bool>("--replace_sub_with_add"))
297     options->enable(Algorithms::ReplaceSubWithAdd);
298   if (arser.get<bool>("--resolve_customop_add"))
299     options->enable(Algorithms::ResolveCustomOpAdd);
300   if (arser.get<bool>("--resolve_customop_batchmatmul"))
301     options->enable(Algorithms::ResolveCustomOpBatchMatMul);
302   if (arser.get<bool>("--resolve_customop_matmul"))
303     options->enable(Algorithms::ResolveCustomOpMatMul);
304   if (arser.get<bool>("--resolve_customop_max_pool_with_argmax"))
305     options->enable(Algorithms::ResolveCustomOpMaxPoolWithArgmax);
306   if (arser.get<bool>("--resolve_customop_splitv"))
307     options->enable(Algorithms::ResolveCustomOpSplitV);
308   if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
309     options->enable(Algorithms::ShuffleWeightTo16x1Float32);
310   if (arser.get<bool>("--replace_non_const_fc_with_batch_matmul"))
311     options->enable(Algorithms::ReplaceNonConstFCWithBatchMatMul);
312   if (arser.get<bool>("--substitute_pack_to_reshape"))
313     options->enable(Algorithms::SubstitutePackToReshape);
314   if (arser.get<bool>("--substitute_padv2_to_pad"))
315     options->enable(Algorithms::SubstitutePadV2ToPad);
316   if (arser.get<bool>("--substitute_splitv_to_split"))
317     options->enable(Algorithms::SubstituteSplitVToSplit);
318   if (arser.get<bool>("--substitute_squeeze_to_reshape"))
319     options->enable(Algorithms::SubstituteSqueezeToReshape);
320   if (arser.get<bool>("--substitute_strided_slice_to_reshape"))
321     options->enable(Algorithms::SubstituteStridedSliceToReshape);
322   if (arser.get<bool>("--substitute_transpose_to_reshape"))
323     options->enable(Algorithms::SubstituteTransposeToReshape);
324   if (arser.get<bool>("--transform_min_max_to_relu6"))
325     options->enable(Algorithms::TransformMinMaxToRelu6Pass);
326   if (arser.get<bool>("--transform_min_relu_to_relu6"))
327     options->enable(Algorithms::TransformMinReluToRelu6Pass);
328   if (arser.get<bool>("--decompose_hardswish"))
329     options->enable(Algorithms::DecomposeHardSwishPass);
330   if (arser.get<bool>("--expand_broadcast_const"))
331     options->enable(Algorithms::ExpandBroadcastConst);
332   if (arser.get<bool>("--unroll_unidirseqlstm"))
333     options->enable(Algorithms::UnrollUnidirSeqLSTM);
334
335   if (arser.get<bool>("--mute_warnings"))
336     settings->set(luci::UserSettings::Key::MuteWarnings, true);
337   if (arser.get<bool>("--disable_validation"))
338     settings->set(luci::UserSettings::Key::DisableValidation, true);
339   if (arser.get<bool>("--generate_profile_data"))
340     settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
341
342   std::string input_path = arser.get<std::string>("input");
343   std::string output_path = arser.get<std::string>("output");
344
345   if (arser["--sparsify_tensor"])
346   {
347     options->enable(Algorithms::SparsifyTensorPass);
348     options->param(AlgorithmParameters::Sparsify_tensor_name,
349                    arser.get<std::string>("--sparsify_tensor"));
350     options->param(AlgorithmParameters::Sparsify_traversal_order,
351                    arser.get<std::string>("--sparsify_traversal_order"));
352     options->param(AlgorithmParameters::Sparsify_format,
353                    arser.get<std::string>("--sparsify_format"));
354     if (arser["--sparsify_block_size"])
355       options->param(AlgorithmParameters::Sparsify_block_size,
356                      arser.get<std::string>("--sparsify_block_size"));
357     else
358     {
359       std::cerr << "ERROR: Block size not provided" << std::endl;
360       return 255;
361     }
362     options->param(AlgorithmParameters::Sparsify_block_map,
363                    arser.get<std::string>("--sparsify_block_map"));
364   }
365
366   if (arser.get<bool>("--convert_nchw_to_nhwc"))
367   {
368     options->enable(Algorithms::ConvertNCHWToNHWC);
369     if (arser.get<bool>("--nchw_to_nhwc_input_shape"))
370       options->param(AlgorithmParameters::NCHW_to_NHWC_input_shape, "true");
371     if (arser.get<bool>("--nchw_to_nhwc_output_shape"))
372       options->param(AlgorithmParameters::NCHW_to_NHWC_output_shape, "true");
373   }
374
375   // Change output nodes
376   bool change_outputs = false;
377   std::vector<std::string> new_outputs;
378   if (arser["--change_outputs"])
379   {
380     change_outputs = true;
381     auto csv_nodes = arser.get<std::string>("--change_outputs");
382     csv_tokenize(csv_nodes, new_outputs);
383   }
384
385   bool dynamic_batch_to_single_batch = false;
386   if (arser.get<bool>("--dynamic_batch_to_single_batch"))
387   {
388     dynamic_batch_to_single_batch = true;
389   }
390
391   // Import from input Circle file
392   luci::ImporterEx importerex;
393   auto module = importerex.importVerifyModule(input_path);
394   if (module.get() == nullptr)
395     return EXIT_FAILURE;
396
397   // Convert dynamic batch to single batch
398   // Why here? It has to be done before 'optimize', because most optimization
399   // passes are written based on static shapes
400   if (dynamic_batch_to_single_batch)
401   {
402     luci::dynamic_batch_to_single_batch(module.get());
403
404     if (!luci::validate_shape(module.get()))
405     {
406       if (settings->get(luci::UserSettings::Key::DisableValidation))
407         std::cerr
408           << "WARNING: Invalid shape detected after converting dynamic batch to single batch"
409           << std::endl;
410       else
411       {
412         std::cerr << "ERROR: Invalid shape detected after converting dynamic batch to single batch"
413                   << std::endl;
414         return 255;
415       }
416     }
417   }
418
419   if (change_outputs)
420   {
421     auto graph = module->graph(0);
422     luci::change_outputs(graph, new_outputs);
423   }
424
425   // call luci optimizations for module
426   optimizer.optimize(module.get());
427
428   for (size_t idx = 0; idx < module->size(); ++idx)
429   {
430     auto graph = module->graph(idx);
431
432     // call luci optimizations for graph
433     optimizer.optimize(graph);
434     optimizer.sparsify(graph);
435
436     if (!luci::validate(graph))
437     {
438       if (settings->get(luci::UserSettings::Key::DisableValidation))
439         std::cerr << "WARNING: Optimized graph is invalid" << std::endl;
440       else
441       {
442         std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
443         return 255;
444       }
445     }
446   }
447
448   // Export to output Circle file
449   luci::CircleExporter exporter;
450
451   luci::CircleFileExpContract contract(module.get(), output_path);
452
453   if (!exporter.invoke(&contract))
454   {
455     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
456     return 255;
457   }
458
459   return 0;
460 }