2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <foder/FileLoader.h>
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
34 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
35 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
37 void print_version(void)
39 std::cout << "circle2circle version " << vconone::get_string() << std::endl;
40 std::cout << vconone::get_copyright() << std::endl;
43 int entry(int argc, char **argv)
45 // Simple argument parser (based on map)
46 luci::CircleOptimizer optimizer;
48 auto options = optimizer.options();
49 auto settings = luci::UserSettings::settings();
51 arser::Arser arser("circle2circle provides circle model optimization and transformations");
53 arser.add_argument("--version")
57 .help("Show version information and exit")
58 .exit_with(print_version);
60 arser.add_argument("--all").nargs(0).required(false).default_value(false).help(
61 "Enable all optimize options");
63 arser.add_argument("--fuse_batchnorm_with_tconv")
67 .help("This will fuse BatchNorm operators to Transposed Convolution operator");
69 arser.add_argument("--fuse_bcq")
73 .help("This will fuse operators and apply Binary Coded Quantization");
75 arser.add_argument("--fuse_instnorm")
79 .help("This will fuse operators to InstanceNorm operator");
81 arser.add_argument("--resolve_customop_add")
85 .help("This will convert Custom(Add) to Add operator");
87 arser.add_argument("--resolve_customop_batchmatmul")
91 .help("This will convert Custom(BatchMatmul) to BatchMatmul operator");
93 arser.add_argument("--resolve_customop_matmul")
97 .help("This will convert Custom(Matmul) to Matmul operator");
99 arser.add_argument("--mute_warnings")
102 .default_value(false)
103 .help("This will turn off warning messages");
105 arser.add_argument("--disable_validation")
108 .default_value(false)
109 .help("This will turn off operator validations. May help input model investigation.");
111 arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
112 arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
116 arser.parse(argc, argv);
118 catch (const std::runtime_error &err)
120 std::cout << err.what() << std::endl;
125 if (arser.get<bool>("--all"))
127 options->enable(Algorithms::FuseBCQ);
128 options->enable(Algorithms::FuseInstanceNorm);
129 options->enable(Algorithms::ResolveCustomOpAdd);
130 options->enable(Algorithms::ResolveCustomOpBatchMatMul);
131 options->enable(Algorithms::ResolveCustomOpMatMul);
133 if (arser.get<bool>("--fuse_batchnorm_with_tconv"))
134 options->enable(Algorithms::FuseBatchNormWithTConv);
135 if (arser.get<bool>("--fuse_bcq"))
136 options->enable(Algorithms::FuseBCQ);
137 if (arser.get<bool>("--fuse_instnorm"))
138 options->enable(Algorithms::FuseInstanceNorm);
139 if (arser.get<bool>("--resolve_customop_add"))
140 options->enable(Algorithms::ResolveCustomOpAdd);
141 if (arser.get<bool>("--resolve_customop_batchmatmul"))
142 options->enable(Algorithms::ResolveCustomOpBatchMatMul);
143 if (arser.get<bool>("--resolve_customop_matmul"))
144 options->enable(Algorithms::ResolveCustomOpMatMul);
146 if (arser.get<bool>("--mute_warnings"))
147 settings->set(luci::UserSettings::Key::MuteWarnings, true);
148 if (arser.get<bool>("--disable_validation"))
149 settings->set(luci::UserSettings::Key::DisableValidation, true);
151 std::string input_path = arser.get<std::string>("input");
152 std::string output_path = arser.get<std::string>("output");
154 // Load model from the file
155 foder::FileLoader file_loader{input_path};
156 std::vector<char> model_data;
160 model_data = file_loader.load();
162 catch (const std::runtime_error &err)
164 std::cerr << err.what() << std::endl;
168 flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
169 if (!circle::VerifyModelBuffer(verifier))
171 std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
175 const circle::Model *circle_model = circle::GetModel(model_data.data());
176 if (circle_model == nullptr)
178 std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
182 // Import from input Circle file
183 luci::Importer importer;
184 auto module = importer.importModule(circle_model);
186 for (size_t idx = 0; idx < module->size(); ++idx)
188 auto graph = module->graph(idx);
190 // call luci optimizations
191 optimizer.optimize(graph);
193 if (!luci::validate(graph))
195 if (settings->get(luci::UserSettings::Key::DisableValidation))
196 std::cerr << "WARNING: Optimized graph is invalid" << std::endl;
199 std::cerr << "ERROR: Optimized graph is invalid" << std::endl;
205 // Export to output Circle file
206 luci::CircleExporter exporter;
208 luci::CircleFileExpContract contract(module.get(), output_path);
210 if (!exporter.invoke(&contract))
212 std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;