2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <foder/FileLoader.h>
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
34 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
35 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
37 void print_version(void)
39 std::cout << "circle2circle version " << vconone::get_string() << std::endl;
40 std::cout << vconone::get_copyright() << std::endl;
43 int entry(int argc, char **argv)
45 // Simple argument parser (based on map)
46 luci::CircleOptimizer optimizer;
48 auto options = optimizer.options();
49 auto settings = luci::UserSettings::settings();
51 arser::Arser arser("circle2circle provides circle model optimization and transformations");
53 arser.add_argument("--version")
57 .help("Show version information and exit")
58 .exit_with(print_version);
60 arser.add_argument("--all").nargs(0).required(false).default_value(false).help(
61 "Enable all optimize options");
63 arser.add_argument("--fold_dequantize")
67 .help("This will fold dequantize op");
69 arser.add_argument("--fuse_activation_function")
73 .help("This will fuse Activation function to a preceding operator");
75 arser.add_argument("--fuse_add_with_tconv")
79 .help("This will fuse Add operator to Transposed Convolution operator");
81 arser.add_argument("--fuse_batchnorm_with_tconv")
85 .help("This will fuse BatchNorm operators to Transposed Convolution operator");
87 arser.add_argument("--fuse_bcq")
91 .help("This will fuse operators and apply Binary Coded Quantization");
93 arser.add_argument("--fuse_instnorm")
97 .help("This will fuse operators to InstanceNorm operator");
99 arser.add_argument("--make_batchnorm_gamma_positive")
102 .default_value(false)
103 .help("This will make negative gamma of BatchNorm into a small positive value (1e-10). Note "
104 "that this pass can change the execution result of the model. So, use it only when the "
105 "impact is known to be acceptable.");
107 arser.add_argument("--fuse_preactivation_batchnorm")
110 .default_value(false)
111 .help("This will fuse BatchNorm operators of pre-activations to Convolution operator");
113 arser.add_argument("--remove_redundant_transpose")
116 .default_value(false)
117 .help("This will fuse or remove subsequent Transpose operators");
119 arser.add_argument("--replace_cw_mul_add_with_depthwise_conv")
122 .default_value(false)
123 .help("This will replace channel-wise mul/add with DepthwiseConv2D operator");
125 arser.add_argument("--resolve_customop_add")
128 .default_value(false)
129 .help("This will convert Custom(Add) to Add operator");
131 arser.add_argument("--resolve_customop_batchmatmul")
134 .default_value(false)
135 .help("This will convert Custom(BatchMatmul) to BatchMatmul operator");
137 arser.add_argument("--resolve_customop_matmul")
140 .default_value(false)
141 .help("This will convert Custom(Matmul) to Matmul operator");
143 arser.add_argument("--shuffle_weight_to_16x1float32")
146 .default_value(false)
147 .help("This will convert weight format of FullyConnected to SHUFFLED16x1FLOAT32. Note that "
148 "it only converts weights whose row is a multiple of 16");
150 arser.add_argument("--substitute_pack_to_reshape")
153 .default_value(false)
154 .help("This will convert single input Pack to Reshape");
156 arser.add_argument("--mute_warnings")
159 .default_value(false)
160 .help("This will turn off warning messages");
162 arser.add_argument("--disable_validation")
165 .default_value(false)
166 .help("This will turn off operator validations. May help input model investigation.");
168 arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
169 arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
171 // sparsification argument
172 arser.add_argument("--sparsify_tensor")
174 .type(arser::DataType::STR)
176 .help("Tensor name that you want to sparsify");
178 arser.add_argument("--sparsify_traversal_order")
180 .type(arser::DataType::STR)
182 .default_value("0,1,2,3")
183 .help("Traversal order of dimensions. Default value: 0,1,2,3");
185 arser.add_argument("--sparsify_format")
187 .type(arser::DataType::STR)
189 .default_value("d,s")
190 .help("Format of each dimension. 'd' stands for dense, 's' stands for sparse(CSR). Default "
193 arser.add_argument("--sparsify_block_size")
195 .type(arser::DataType::STR)
197 .help("Size of each block dimension");
199 arser.add_argument("--sparsify_block_map")
201 .type(arser::DataType::STR)
203 .default_value("0,1")
204 .help("Map from block dimension to the original tensor dimension. Default value: 0,1");
208 arser.parse(argc, argv);
210 catch (const std::runtime_error &err)
212 std::cout << err.what() << std::endl;
217 if (arser.get<bool>("--all"))
219 options->enable(Algorithms::FuseBCQ);
220 options->enable(Algorithms::FuseInstanceNorm);
221 options->enable(Algorithms::ResolveCustomOpAdd);
222 options->enable(Algorithms::ResolveCustomOpBatchMatMul);
223 options->enable(Algorithms::ResolveCustomOpMatMul);
224 options->enable(Algorithms::RemoveRedundantTranspose);
225 options->enable(Algorithms::SubstitutePackToReshape);
227 if (arser.get<bool>("--fold_dequantize"))
228 options->enable(Algorithms::FoldDequantize);
229 if (arser.get<bool>("--fuse_activation_function"))
230 options->enable(Algorithms::FuseActivationFunction);
231 if (arser.get<bool>("--fuse_add_with_tconv"))
232 options->enable(Algorithms::FuseAddWithTConv);
233 if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
234 options->enable(Algorithms::FuseBatchNormWithTConv);
235 if (arser.get<bool>("--fuse_bcq"))
236 options->enable(Algorithms::FuseBCQ);
237 if (arser.get<bool>("--fuse_instnorm"))
238 options->enable(Algorithms::FuseInstanceNorm);
239 if (arser.get<bool>("--make_batchnorm_gamma_positive"))
240 options->enable(Algorithms::MakeBatchNormGammaPositive);
241 if (arser.get<bool>("--fuse_preactivation_batchnorm"))
242 options->enable(Algorithms::FusePreActivationBatchNorm);
243 if (arser.get<bool>("--remove_redundant_transpose"))
244 options->enable(Algorithms::RemoveRedundantTranspose);
245 if (arser.get<bool>("--replace_cw_mul_add_with_depthwise_conv"))
246 options->enable(Algorithms::ReplaceMulAddWithDepthwiseConv);
247 if (arser.get<bool>("--resolve_customop_add"))
248 options->enable(Algorithms::ResolveCustomOpAdd);
249 if (arser.get<bool>("--resolve_customop_batchmatmul"))
250 options->enable(Algorithms::ResolveCustomOpBatchMatMul);
251 if (arser.get<bool>("--resolve_customop_matmul"))
252 options->enable(Algorithms::ResolveCustomOpMatMul);
253 if (arser.get<bool>("--shuffle_weight_to_16x1float32"))
254 options->enable(Algorithms::ShuffleWeightTo16x1Float32);
255 if (arser.get<bool>("--substitute_pack_to_reshape"))
256 options->enable(Algorithms::SubstitutePackToReshape);
258 if (arser.get<bool>("--mute_warnings"))
259 settings->set(luci::UserSettings::Key::MuteWarnings, true);
260 if (arser.get<bool>("--disable_validation"))
261 settings->set(luci::UserSettings::Key::DisableValidation, true);
263 std::string input_path = arser.get<std::string>("input");
264 std::string output_path = arser.get<std::string>("output");
266 if (arser["--sparsify_tensor"])
268 options->enable(Algorithms::SparsifyTensorPass);
269 options->param(AlgorithmParameters::Sparsify_tensor_name,
270 arser.get<std::string>("--sparsify_tensor"));
271 options->param(AlgorithmParameters::Sparsify_traversal_order,
272 arser.get<std::string>("--sparsify_traversal_order"));
273 options->param(AlgorithmParameters::Sparsify_format,
274 arser.get<std::string>("--sparsify_format"));
275 if (arser["--sparsify_block_size"])
276 options->param(AlgorithmParameters::Sparsify_block_size,
277 arser.get<std::string>("--sparsify_block_size"));
280 std::cerr << "ERROR: Block size not provided" << std::endl;
283 options->param(AlgorithmParameters::Sparsify_block_map,
284 arser.get<std::string>("--sparsify_block_map"));
287 // Load model from the file
288 foder::FileLoader file_loader{input_path};
289 std::vector<char> model_data;
293 model_data = file_loader.load();
295 catch (const std::runtime_error &err)
297 std::cerr << err.what() << std::endl;
301 flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
302 if (!circle::VerifyModelBuffer(verifier))
304 std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
308 const circle::Model *circle_model = circle::GetModel(model_data.data());
309 if (circle_model == nullptr)
311 std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
315 // Import from input Circle file
316 luci::Importer importer;
317 auto module = importer.importModule(circle_model);
319 // call luci optimizations for module
320 optimizer.optimize(module.get());
322 for (size_t idx = 0; idx < module->size(); ++idx)
324 auto graph = module->graph(idx);
326 // call luci optimizations for graph
327 optimizer.optimize(graph);
328 optimizer.sparsify(graph);
330 if (!luci::validate(graph))
332 if (settings->get(luci::UserSettings::Key::DisableValidation))
333 std::cerr << "WARNING: Optimized graph is invalid" << std::endl;
336 std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
342 // Export to output Circle file
343 luci::CircleExporter exporter;
345 luci::CircleFileExpContract contract(module.get(), output_path);
347 if (!exporter.invoke(&contract))
349 std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;