2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <luci/ImporterEx.h>
18 #include <luci/CircleOptimizer.h>
19 #include <luci/DynamicBatchToSingleBatch.h>
20 #include <luci/Service/ChangeOutputs.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
37 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
38 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
40 void print_version(void)
42 std::cout << "circle2circle version " << vconone::get_string() << std::endl;
43 std::cout << vconone::get_copyright() << std::endl;
/**
 * @brief Split a comma-separated string and append each field to @p result.
 *
 * @param data   Comma-separated text, e.g. "node_a,node_b".
 * @param result Output container. New fields are appended; existing contents are kept.
 *
 * Fields are returned verbatim (no whitespace trimming), so "a, b" yields "a" and " b".
 * Empty fields between or before delimiters are kept ("a,,b" -> "a", "", "b"), but a
 * trailing delimiter does not add an empty field ("a," -> "a"), and an empty @p data
 * appends nothing. This follows std::getline's delimiter handling.
 */
void csv_tokenize(const std::string &data, std::vector<std::string> &result)
{
  const char delim = ',';
  std::istringstream stream(data);

  // Declared once outside the loop; std::getline clears it before reading each field.
  std::string token;
  while (std::getline(stream, token, delim))
    result.push_back(token);
}
56 void add_switch(arser::Arser &arser, const char *opt, const char *desc)
58 arser.add_argument(opt).nargs(0).default_value(false).help(desc);
61 int entry(int argc, char **argv)
63 // Simple argument parser (based on map)
64 luci::CircleOptimizer optimizer;
66 auto options = optimizer.options();
67 auto settings = luci::UserSettings::settings();
69 arser::Arser arser("circle2circle provides circle model optimization and transformations");
71 arser::Helper::add_version(arser, print_version);
72 arser::Helper::add_verbose(arser);
74 add_switch(arser, "--fold_add_v2", "This will fold AddV2 operators with constant inputs");
75 add_switch(arser, "--fold_cast", "This will fold Cast operators with constant input");
76 add_switch(arser, "--fold_densify",
77 "This will fold Densify operators with sparse constant input");
78 add_switch(arser, "--fold_dequantize", "This will fold dequantize op");
79 add_switch(arser, "--fold_dwconv",
80 "This will fold Depthwise Convolution operator with constant inputs");
81 add_switch(arser, "--fold_fully_connected",
82 "This will fold FullyConnected operator with constant inputs");
83 add_switch(arser, "--fold_gather", "This will fold Gather operator");
84 add_switch(arser, "--fold_sparse_to_dense", "This will fold SparseToDense operator");
85 add_switch(arser, "--forward_reshape_to_unaryop",
86 "This will move Reshape after UnaryOp for centain condition");
87 add_switch(arser, "--forward_transpose_op",
88 "This will move Transpose Op forward if possible (for further optimization)");
89 add_switch(arser, "--fuse_activation_function",
90 "This will fuse Activation function to a preceding operator");
91 add_switch(arser, "--fuse_add_with_fully_connected",
92 "This will fuse Add operator to FullyConnected operator");
93 add_switch(arser, "--fuse_add_with_tconv",
94 "This will fuse Add operator to Transposed Convolution operator");
95 add_switch(arser, "--fuse_batchnorm_with_conv",
96 "This will fuse BatchNorm operators to Convolution operator");
97 add_switch(arser, "--fuse_batchnorm_with_dwconv",
98 "This will fuse BatchNorm operators to Depthwise Convolution operator");
99 add_switch(arser, "--fuse_batchnorm_with_tconv",
100 "This will fuse BatchNorm operators to Transposed Convolution operator");
101 add_switch(arser, "--fuse_bcq", "This will fuse operators and apply Binary Coded Quantization");
102 add_switch(arser, "--fuse_instnorm", "This will fuse operators to InstanceNorm operator");
103 add_switch(arser, "--fuse_mean_with_mean",
104 "This will fuse two Mean operations when they follow one by one. This will fold them "
105 "into one operation and merge reduction indices.");
106 add_switch(arser, "--fuse_transpose_with_mean",
107 "This will fuse Mean operation with a preceding Transpose under certain conditions.");
108 add_switch(arser, "--make_batchnorm_gamma_positive",
109 "This will make negative gamma of BatchNorm into a small positive value (1e-10). "
110 "Note that this pass can change the execution result of the model. So, use it only "
111 "when the impact is known to be acceptable.");
112 add_switch(arser, "--fuse_preactivation_batchnorm",
113 "This will fuse BatchNorm operators of pre-activations to Convolution operator");
114 add_switch(arser, "--fuse_prelu", "This will fuse operators to PReLU operator");
115 add_switch(arser, "--fuse_gelu", "This will fuse operators to GeLU operator");
116 add_switch(arser, "--remove_duplicate_const", "This will remove all duplicate constant nodes");
117 add_switch(arser, "--remove_fakequant", "This will remove FakeQuant operators");
118 add_switch(arser, "--remove_quantdequant", "This will remove Quantize-Dequantize sequence");
119 add_switch(arser, "--remove_redundant_quantize", "This will remove redundant Quantize operators");
120 add_switch(arser, "--remove_redundant_reshape",
121 "This will fuse or remove subsequent Reshape operators");
122 add_switch(arser, "--remove_redundant_transpose",
123 "This will fuse or remove subsequent Transpose operators");
124 add_switch(arser, "--remove_unnecessary_reshape",
125 "This will remove unnecessary reshape operators");
126 add_switch(arser, "--remove_unnecessary_slice", "This will remove unnecessary slice operators");
127 add_switch(arser, "--remove_unnecessary_strided_slice",
128 "This will remove unnecessary strided slice operators");
129 add_switch(arser, "--remove_unnecessary_split", "This will remove unnecessary split operators");
130 add_switch(arser, "--replace_cw_mul_add_with_depthwise_conv",
131 "This will replace channel-wise mul/add with DepthwiseConv2D operator");
132 add_switch(arser, "--replace_sub_with_add", "This will replace sub with add operator");
133 add_switch(arser, "--resolve_customop_add", "This will convert Custom(Add) to Add operator");
134 add_switch(arser, "--resolve_customop_batchmatmul",
135 "This will convert Custom(BatchMatmul) to BatchMatmul operator");
136 add_switch(arser, "--resolve_customop_matmul",
137 "This will convert Custom(Matmul) to Matmul operator");
138 add_switch(arser, "--resolve_customop_max_pool_with_argmax",
139 "This will convert Custom(MaxPoolWithArgmax) to equivalent set of operators");
140 add_switch(arser, "--resolve_customop_splitv",
141 "This will convert Custom(SplitV) to SplitV operator");
142 add_switch(arser, "--shuffle_weight_to_16x1float32",
143 "This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32. Note that "
144 "it only converts weights whose row is a multiple of 16");
145 add_switch(arser, "--replace_non_const_fc_with_batch_matmul",
146 "Replace FullyConnected with BatchMatMul when its weight is non-constant");
147 add_switch(arser, "--substitute_pack_to_reshape",
148 "This will convert single input Pack to Reshape");
149 add_switch(arser, "--substitute_padv2_to_pad",
150 "This will convert certain condition PadV2 to Pad");
151 add_switch(arser, "--substitute_splitv_to_split",
152 "This will convert certain condition SplitV to Split operator");
153 add_switch(arser, "--substitute_squeeze_to_reshape",
154 "This will convert certain condition Squeeze to Reshape");
155 add_switch(arser, "--substitute_strided_slice_to_reshape",
156 "This will convert certain condition Strided_Slice to Reshape");
157 add_switch(arser, "--substitute_transpose_to_reshape",
158 "This will convert single input Transpose to Reshape");
159 add_switch(arser, "--expand_broadcast_const", "This will expand broadcastable constant inputs");
160 add_switch(arser, "--unroll_unidirseqlstm", "Unroll UnidirectionalSequenceLSTM operator.");
161 add_switch(arser, "--convert_nchw_to_nhwc",
162 "Experimental: This will convert NCHW operators to NHWC under the assumption that "
163 "input model is NCHW.");
164 add_switch(arser, "--nchw_to_nhwc_input_shape",
165 "Convert the input shape of the model (argument for --convert_nchw_to_nhwc).");
166 add_switch(arser, "--nchw_to_nhwc_output_shape",
167 "Convert the output shape of the model (argument for --convert_nchw_to_nhwc).");
168 add_switch(arser, "--transform_min_max_to_relu6",
169 "Transform Minimum(6)-Maximum(0) pattern to Relu6 operator");
170 add_switch(arser, "--transform_min_relu_to_relu6",
171 "Transform Minimum(6)-Relu pattern to Relu6 operator");
172 add_switch(arser, "--decompose_hardswish",
173 "Decompose HardSwish operator to Add, Mul and Relu6 operators");
174 add_switch(arser, "--mute_warnings", "This will turn off warning messages");
175 add_switch(arser, "--disable_validation",
176 "This will turn off operator validations. May help input model investigation.");
177 add_switch(arser, "--generate_profile_data", "This will turn on profiling data generation.");
179 // Convert dynamic batch to single batch
180 // Users have to use this option only when the first dimension of rank 4 input (NHWC or NCHW)
181 // is dynamic. Remove this comment after non-rank 4 is supported.
182 add_switch(arser, "--dynamic_batch_to_single_batch",
183 "Convert dynamic batch size (first dimension) of inputs to 1.");
185 arser.add_argument("--change_outputs")
186 .help("Experimental: Change first subgraph output nodes to CSV names");
188 arser.add_argument("input").help("Input circle model");
189 arser.add_argument("output").help("Output circle model");
191 // sparsification argument
192 arser.add_argument("--sparsify_tensor").help("Tensor name that you want to sparsify");
194 arser.add_argument("--sparsify_traversal_order")
195 .default_value("0,1,2,3")
196 .help("Traversal order of dimensions. Default value: 0,1,2,3");
198 arser.add_argument("--sparsify_format")
199 .default_value("d,s")
200 .help("Format of each dimension. 'd' stands for dense, 's' stands for sparse(CSR). Default "
203 arser.add_argument("--sparsify_block_size").help("Size of each block dimension");
205 arser.add_argument("--sparsify_block_map")
206 .default_value("0,1")
207 .help("Map from block dimension to the original tensor dimension. Default value: 0,1");
211 arser.parse(argc, argv);
213 catch (const std::runtime_error &err)
215 std::cerr << err.what() << std::endl;
220 if (arser.get<bool>("--verbose"))
222 // The third parameter of setenv means REPLACE.
223 // If REPLACE is zero, it does not overwrite an existing value.
224 setenv("LUCI_LOG", "100", 0);
226 if (arser.get<bool>("--fold_add_v2"))
227 options->enable(Algorithms::FoldAddV2);
228 if (arser.get<bool>("--fold_cast"))
229 options->enable(Algorithms::FoldCast);
230 if (arser.get<bool>("--fold_densify"))
231 options->enable(Algorithms::FoldDensify);
232 if (arser.get<bool>("--fold_dequantize"))
233 options->enable(Algorithms::FoldDequantize);
234 if (arser.get<bool>("--fold_dwconv"))
235 options->enable(Algorithms::FoldDepthwiseConv2D);
236 if (arser.get<bool>("--fold_fully_connected"))
237 options->enable(Algorithms::FoldFullyConnected);
238 if (arser.get<bool>("--fold_gather"))
239 options->enable(Algorithms::FoldGather);
240 if (arser.get<bool>("--fold_sparse_to_dense"))
241 options->enable(Algorithms::FoldSparseToDense);
242 if (arser.get<bool>("--forward_reshape_to_unaryop"))
243 options->enable(Algorithms::ForwardReshapeToUnaryOp);
244 if (arser.get<bool>("--forward_transpose_op"))
245 options->enable(Algorithms::ForwardTransposeOp);
246 if (arser.get<bool>("--fuse_activation_function"))
247 options->enable(Algorithms::FuseActivationFunction);
248 if (arser.get<bool>("--fuse_batchnorm_with_conv"))
249 options->enable(Algorithms::FuseBatchNormWithConv);
250 if (arser.get<bool>("--fuse_add_with_fully_connected"))
251 options->enable(Algorithms::FuseAddWithFullyConnected);
252 if (arser.get<bool>("--fuse_add_with_tconv"))
253 options->enable(Algorithms::FuseAddWithTConv);
254 if (arser.get<bool>("--fuse_batchnorm_with_dwconv"))
255 options->enable(Algorithms::FuseBatchNormWithDwConv);
256 if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
257 options->enable(Algorithms::FuseBatchNormWithTConv);
258 if (arser.get<bool>("--fuse_bcq"))
259 options->enable(Algorithms::FuseBCQ);
260 if (arser.get<bool>("--fuse_instnorm"))
261 options->enable(Algorithms::FuseInstanceNorm);
262 if (arser.get<bool>("--fuse_mean_with_mean"))
263 options->enable(Algorithms::FuseMeanWithMean);
264 if (arser.get<bool>("--make_batchnorm_gamma_positive"))
265 options->enable(Algorithms::MakeBatchNormGammaPositive);
266 if (arser.get<bool>("--fuse_preactivation_batchnorm"))
267 options->enable(Algorithms::FusePreActivationBatchNorm);
268 if (arser.get<bool>("--fuse_prelu"))
269 options->enable(Algorithms::FusePRelu);
270 if (arser.get<bool>("--fuse_gelu"))
271 options->enable(Algorithms::FuseGelu);
272 if (arser.get<bool>("--fuse_transpose_with_mean"))
273 options->enable(Algorithms::FuseTransposeWithMean);
274 if (arser.get<bool>("--remove_duplicate_const"))
275 options->enable(Algorithms::RemoveDuplicateConst);
276 if (arser.get<bool>("--remove_fakequant"))
277 options->enable(Algorithms::RemoveFakeQuant);
278 if (arser.get<bool>("--remove_quantdequant"))
279 options->enable(Algorithms::RemoveQuantDequantSeq);
280 if (arser.get<bool>("--remove_redundant_quantize"))
281 options->enable(Algorithms::RemoveRedundantQuantize);
282 if (arser.get<bool>("--remove_redundant_reshape"))
283 options->enable(Algorithms::RemoveRedundantReshape);
284 if (arser.get<bool>("--remove_redundant_transpose"))
285 options->enable(Algorithms::RemoveRedundantTranspose);
286 if (arser.get<bool>("--remove_unnecessary_reshape"))
287 options->enable(Algorithms::RemoveUnnecessaryReshape);
288 if (arser.get<bool>("--remove_unnecessary_slice"))
289 options->enable(Algorithms::RemoveUnnecessarySlice);
290 if (arser.get<bool>("--remove_unnecessary_strided_slice"))
291 options->enable(Algorithms::RemoveUnnecessaryStridedSlice);
292 if (arser.get<bool>("--remove_unnecessary_split"))
293 options->enable(Algorithms::RemoveUnnecessarySplit);
294 if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
295 options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
296 if (arser.get<bool>("--replace_sub_with_add"))
297 options->enable(Algorithms::ReplaceSubWithAdd);
298 if (arser.get<bool>("--resolve_customop_add"))
299 options->enable(Algorithms::ResolveCustomOpAdd);
300 if (arser.get<bool>("--resolve_customop_batchmatmul"))
301 options->enable(Algorithms::ResolveCustomOpBatchMatMul);
302 if (arser.get<bool>("--resolve_customop_matmul"))
303 options->enable(Algorithms::ResolveCustomOpMatMul);
304 if (arser.get<bool>("--resolve_customop_max_pool_with_argmax"))
305 options->enable(Algorithms::ResolveCustomOpMaxPoolWithArgmax);
306 if (arser.get<bool>("--resolve_customop_splitv"))
307 options->enable(Algorithms::ResolveCustomOpSplitV);
308 if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
309 options->enable(Algorithms::ShuffleWeightTo16x1Float32);
310 if (arser.get<bool>("--replace_non_const_fc_with_batch_matmul"))
311 options->enable(Algorithms::ReplaceNonConstFCWithBatchMatMul);
312 if (arser.get<bool>("--substitute_pack_to_reshape"))
313 options->enable(Algorithms::SubstitutePackToReshape);
314 if (arser.get<bool>("--substitute_padv2_to_pad"))
315 options->enable(Algorithms::SubstitutePadV2ToPad);
316 if (arser.get<bool>("--substitute_splitv_to_split"))
317 options->enable(Algorithms::SubstituteSplitVToSplit);
318 if (arser.get<bool>("--substitute_squeeze_to_reshape"))
319 options->enable(Algorithms::SubstituteSqueezeToReshape);
320 if (arser.get<bool>("--substitute_strided_slice_to_reshape"))
321 options->enable(Algorithms::SubstituteStridedSliceToReshape);
322 if (arser.get<bool>("--substitute_transpose_to_reshape"))
323 options->enable(Algorithms::SubstituteTransposeToReshape);
324 if (arser.get<bool>("--transform_min_max_to_relu6"))
325 options->enable(Algorithms::TransformMinMaxToRelu6Pass);
326 if (arser.get<bool>("--transform_min_relu_to_relu6"))
327 options->enable(Algorithms::TransformMinReluToRelu6Pass);
328 if (arser.get<bool>("--decompose_hardswish"))
329 options->enable(Algorithms::DecomposeHardSwishPass);
330 if (arser.get<bool>("--expand_broadcast_const"))
331 options->enable(Algorithms::ExpandBroadcastConst);
332 if (arser.get<bool>("--unroll_unidirseqlstm"))
333 options->enable(Algorithms::UnrollUnidirSeqLSTM);
335 if (arser.get<bool>("--mute_warnings"))
336 settings->set(luci::UserSettings::Key::MuteWarnings, true);
337 if (arser.get<bool>("--disable_validation"))
338 settings->set(luci::UserSettings::Key::DisableValidation, true);
339 if (arser.get<bool>("--generate_profile_data"))
340 settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
342 std::string input_path = arser.get<std::string>("input");
343 std::string output_path = arser.get<std::string>("output");
345 if (arser["--sparsify_tensor"])
347 options->enable(Algorithms::SparsifyTensorPass);
348 options->param(AlgorithmParameters::Sparsify_tensor_name,
349 arser.get<std::string>("--sparsify_tensor"));
350 options->param(AlgorithmParameters::Sparsify_traversal_order,
351 arser.get<std::string>("--sparsify_traversal_order"));
352 options->param(AlgorithmParameters::Sparsify_format,
353 arser.get<std::string>("--sparsify_format"));
354 if (arser["--sparsify_block_size"])
355 options->param(AlgorithmParameters::Sparsify_block_size,
356 arser.get<std::string>("--sparsify_block_size"));
359 std::cerr << "ERROR: Block size not provided" << std::endl;
362 options->param(AlgorithmParameters::Sparsify_block_map,
363 arser.get<std::string>("--sparsify_block_map"));
366 if (arser.get<bool>("--convert_nchw_to_nhwc"))
368 options->enable(Algorithms::ConvertNCHWToNHWC);
369 if (arser.get<bool>("--nchw_to_nhwc_input_shape"))
370 options->param(AlgorithmParameters::NCHW_to_NHWC_input_shape, "true");
371 if (arser.get<bool>("--nchw_to_nhwc_output_shape"))
372 options->param(AlgorithmParameters::NCHW_to_NHWC_output_shape, "true");
375 // Change output nodes
376 bool change_outputs = false;
377 std::vector<std::string> new_outputs;
378 if (arser["--change_outputs"])
380 change_outputs = true;
381 auto csv_nodes = arser.get<std::string>("--change_outputs");
382 csv_tokenize(csv_nodes, new_outputs);
385 bool dynamic_batch_to_single_batch = false;
386 if (arser.get<bool>("--dynamic_batch_to_single_batch"))
388 dynamic_batch_to_single_batch = true;
391 // Import from input Circle file
392 luci::ImporterEx importerex;
393 auto module = importerex.importVerifyModule(input_path);
394 if (module.get() == nullptr)
397 // Convert dynamic batch to single batch
398 // Why here? It has to be done before 'optimize', because most optimization
399 // passes are written based on static shapes
400 if (dynamic_batch_to_single_batch)
402 luci::dynamic_batch_to_single_batch(module.get());
404 if (!luci::validate_shape(module.get()))
406 if (settings->get(luci::UserSettings::Key::DisableValidation))
408 << "WARNING: Invalid shape detected after converting dynamic batch to single batch"
412 std::cerr << "ERROR: Invalid shape detected after converting dynamic batch to single batch"
421 auto graph = module->graph(0);
422 luci::change_outputs(graph, new_outputs);
425 // call luci optimizations for module
426 optimizer.optimize(module.get());
428 for (size_t idx = 0; idx < module->size(); ++idx)
430 auto graph = module->graph(idx);
432 // call luci optimizations for graph
433 optimizer.optimize(graph);
434 optimizer.sparsify(graph);
436 if (!luci::validate(graph))
438 if (settings->get(luci::UserSettings::Key::DisableValidation))
439 std::cerr << "WARNING: Optimized graph is invalid" << std::endl;
442 std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
448 // Export to output Circle file
449 luci::CircleExporter exporter;
451 luci::CircleFileExpContract contract(module.get(), output_path);
453 if (!exporter.invoke(&contract))
455 std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;