2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <foder/FileLoader.h>
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24 #include <luci/UserSettings.h>
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
35 using OptionHook = std::function<int(const char **)>;
37 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
38 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
/**
 * @brief Print the mutually exclusive quantization options to stdout.
 *
 * Printed by entry() as part of its check that only one of these options is
 * given. The number in the header line is computed from the option list
 * itself. Previously it was hard-coded as "3" while four options were listed.
 */
void print_exclusive_options(void)
{
  static const char *const exclusive_options[] = {
    "--quantize_dequantize_weights",
    "--quantize_with_minmax",
    "--requantize",
    "--force_quantparam",
  };
  const auto num_options = sizeof(exclusive_options) / sizeof(exclusive_options[0]);

  std::cout << "Use only one of the " << num_options << " options below." << std::endl;
  for (const char *option : exclusive_options)
    std::cout << " " << option << std::endl;
}
49 void print_version(void)
51 std::cout << "circle-quantizer version " << vconone::get_string() << std::endl;
52 std::cout << vconone::get_copyright() << std::endl;
55 int entry(int argc, char **argv)
57 // Simple argument parser (based on map)
58 std::map<std::string, OptionHook> argparse;
59 luci::CircleOptimizer optimizer;
61 auto options = optimizer.options();
62 auto settings = luci::UserSettings::settings();
64 const std::string qdqw = "--quantize_dequantize_weights";
65 const std::string qwmm = "--quantize_with_minmax";
66 const std::string rq = "--requantize";
67 const std::string fq = "--force_quantparam";
69 const std::string gpd = "--generate_profile_data";
71 arser::Arser arser("circle-quantizer provides circle model quantization");
73 arser.add_argument("--version")
77 .help("Show version information and exit")
78 .exit_with(print_version);
80 arser.add_argument("-V", "--verbose")
84 .help("output additional information to stdout or stderr");
86 arser.add_argument(qdqw)
88 .type(arser::DataType::STR_VEC)
90 .help("Quantize-dequantize weight values required action before quantization. "
91 "Three arguments required: input_dtype(float32) "
92 "output_dtype(uint8) granularity(layer, channel)");
94 arser.add_argument(qwmm)
96 .type(arser::DataType::STR_VEC)
98 .help("Quantize with min/max values. "
99 "Three arguments required: input_dtype(float32) "
100 "output_dtype(uint8) granularity(layer, channel)");
102 arser.add_argument(rq)
104 .type(arser::DataType::STR_VEC)
106 .help("Requantize a quantized model. "
107 "Two arguments required: input_dtype(int8) "
108 "output_dtype(uint8)");
110 arser.add_argument(fq)
112 .type(arser::DataType::STR_VEC)
115 .help("Write quantization parameters to the specified tensor. "
116 "Three arguments required: tensor_name(string), "
117 "scale(float) zero_point(int)");
119 arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
120 arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
122 arser.add_argument(gpd).nargs(0).required(false).default_value(false).help(
123 "This will turn on profiling data generation.");
127 arser.parse(argc, argv);
129 catch (const std::runtime_error &err)
131 std::cerr << err.what() << std::endl;
137 // only one of qdqw, qwmm, rq, fq option can be used
138 int32_t opt_used = arser[qdqw] ? 1 : 0;
139 opt_used += arser[qwmm] ? 1 : 0;
140 opt_used += arser[rq] ? 1 : 0;
141 opt_used += arser[fq] ? 1 : 0;
144 print_exclusive_options();
149 if (arser.get<bool>("--verbose"))
151 // The third parameter of setenv means REPLACE.
152 // If REPLACE is zero, it does not overwrite an existing value.
153 setenv("LUCI_LOG", "100", 0);
158 auto values = arser.get<std::vector<std::string>>(qdqw);
159 if (values.size() != 3)
164 options->enable(Algorithms::QuantizeDequantizeWeights);
166 options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
167 options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
168 options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
173 auto values = arser.get<std::vector<std::string>>(qwmm);
174 if (values.size() != 3)
179 options->enable(Algorithms::QuantizeWithMinMax);
181 options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
182 options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
183 options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
188 auto values = arser.get<std::vector<std::string>>(rq);
189 if (values.size() != 2)
194 options->enable(Algorithms::Requantize);
196 options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
197 options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
202 auto values = arser.get<std::vector<std::vector<std::string>>>(fq);
204 std::vector<std::string> tensors;
205 std::vector<std::string> scales;
206 std::vector<std::string> zero_points;
208 for (auto const value : values)
210 if (value.size() != 3)
216 tensors.push_back(value[0]);
217 scales.push_back(value[1]);
218 zero_points.push_back(value[2]);
221 options->enable(Algorithms::ForceQuantParam);
223 options->params(AlgorithmParameters::Quantize_tensor_names, tensors);
224 options->params(AlgorithmParameters::Quantize_scales, scales);
225 options->params(AlgorithmParameters::Quantize_zero_points, zero_points);
228 std::string input_path = arser.get<std::string>("input");
229 std::string output_path = arser.get<std::string>("output");
232 settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
234 // Load model from the file
235 foder::FileLoader file_loader{input_path};
236 std::vector<char> model_data = file_loader.load();
238 // Verify flatbuffers
239 flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
240 if (!circle::VerifyModelBuffer(verifier))
242 std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
246 const circle::Model *circle_model = circle::GetModel(model_data.data());
247 if (circle_model == nullptr)
249 std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
253 // Import from input Circle file
254 luci::Importer importer;
255 auto module = importer.importModule(circle_model);
257 for (size_t idx = 0; idx < module->size(); ++idx)
259 auto graph = module->graph(idx);
261 // quantize the graph
262 optimizer.quantize(graph);
264 if (!luci::validate(graph))
266 std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
271 // Export to output Circle file
272 luci::CircleExporter exporter;
274 luci::CircleFileExpContract contract(module.get(), output_path);
276 if (!exporter.invoke(&contract))
278 std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;