2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <luci/ImporterEx.h>
18 #include <luci/CircleQuantizer.h>
19 #include <luci/Service/Validate.h>
20 #include <luci/CircleExporter.h>
21 #include <luci/CircleFileExpContract.h>
22 #include <luci/UserSettings.h>
24 #include <oops/InternalExn.h>
25 #include <arser/arser.h>
26 #include <vconone/vconone.h>
34 using OptionHook = std::function<int(const char **)>;
36 using LayerParam = luci::CircleQuantizer::Options::LayerParam;
37 using Algorithms = luci::CircleQuantizer::Options::Algorithm;
38 using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
// Reads a JSON quantization config file and builds one LayerParam per configured layer name.
//
// filename: path of the config file to open.
// Throws std::runtime_error when the file cannot be opened or when parseFromStream reports
// failure (the parser's error text is appended to the message).
//
// Every entry under the top-level "layers" member may carry a single "name", a list of
// "names", or both; "dtype" and "granularity" are taken from that same entry.
//
// NOTE(review): the declarations of `root` and `errs`, the statements that store each `l`
// into `p`, and the final return are not visible in this excerpt - confirm in the full file.
40 std::vector<std::shared_ptr<LayerParam>> read_layer_params(std::string &filename)
43 std::ifstream ifs(filename);
45 // Failed to open cfg file
46 if (not ifs.is_open())
47 throw std::runtime_error("Cannot open config file. " + filename);
49 Json::CharReaderBuilder builder;
// Parse the whole stream into `root`; on failure `errs` is used as the error detail.
53 if (not parseFromStream(builder, ifs, &root, &errs))
54 throw std::runtime_error("Cannot parse config file (json format). " + errs);
56 auto layers = root["layers"];
57 std::vector<std::shared_ptr<LayerParam>> p;
58 for (auto layer : layers)
// Single-layer form: one "name" with its own dtype and granularity.
60 if (layer.isMember("name"))
62 auto l = std::make_shared<LayerParam>();
64 l->name = layer["name"].asString();
65 l->dtype = layer["dtype"].asString();
66 l->granularity = layer["granularity"].asString();
71 // Multiple names with the same dtype & granularity
72 if (layer.isMember("names"))
// One LayerParam is created per listed name; dtype/granularity come from the enclosing entry.
74 for (auto name : layer["names"])
76 auto l = std::make_shared<LayerParam>();
78 l->name = name.asString();
79 l->dtype = layer["dtype"].asString();
80 l->granularity = layer["granularity"].asString();
/**
 * @brief Print the mutually exclusive quantization options to stdout
 *
 * entry() counts six options when enforcing exclusivity (--quantize_dequantize_weights,
 * --quantize_with_minmax, --requantize, --force_quantparam, --copy_quantparam and
 * --fake_quantize), so all six are listed here. The header line intentionally carries no
 * count, so it cannot drift out of sync with the list again.
 */
void print_exclusive_options(void)
{
  std::cout << "Use only one of the options below." << std::endl;
  std::cout << " --quantize_dequantize_weights" << std::endl;
  std::cout << " --quantize_with_minmax" << std::endl;
  std::cout << " --requantize" << std::endl;
  std::cout << " --force_quantparam" << std::endl;
  std::cout << " --copy_quantparam" << std::endl;
  std::cout << " --fake_quantize" << std::endl;
}
99 void print_version(void)
101 std::cout << "circle-quantizer version " << vconone::get_string() << std::endl;
102 std::cout << vconone::get_copyright() << std::endl;
// Command-line entry point of circle-quantizer.
//
// Registers the options with arser, parses argv, translates the selected options into
// CircleQuantizer options, then imports the input circle model, quantizes and validates every
// graph in it, and exports the result to the output path.
//
// NOTE(review): several original lines (braces, guarding `if`s, the `try`, return statements)
// are not visible in this excerpt; the comments below describe only the statements shown.
105 int entry(int argc, char **argv)
107 luci::CircleQuantizer quantizer;
// Handles used below to enable algorithms, set their parameters and set user settings.
109 auto options = quantizer.options();
110 auto settings = luci::UserSettings::settings();
// Option names, kept in constants so that registration and later lookups share one spelling.
112 const std::string qdqw = "--quantize_dequantize_weights";
113 const std::string qwmm = "--quantize_with_minmax";
114 const std::string rq = "--requantize";
115 const std::string fq = "--force_quantparam";
116 const std::string cq = "--copy_quantparam";
117 const std::string fake_quant = "--fake_quantize";
118 const std::string cfg = "--config";
120 const std::string tf_maxpool = "--TF-style_maxpool";
122 const std::string gpd = "--generate_profile_data";
124 arser::Arser arser("circle-quantizer provides circle model quantization");
// Standard --version and --verbose options.
126 arser::Helper::add_version(arser, print_version);
127 arser::Helper::add_verbose(arser);
// Quantization mode options; each takes its arguments as a vector of strings.
129 arser.add_argument(qdqw)
131 .type(arser::DataType::STR_VEC)
132 .help("Quantize-dequantize weight values required action before quantization. "
133 "Three arguments required: input_model_dtype(float32) "
134 "output_model_dtype(uint8) granularity(layer, channel)");
136 arser.add_argument(qwmm)
138 .type(arser::DataType::STR_VEC)
139 .help("Quantize with min/max values. "
140 "Three arguments required: input_model_dtype(float32) "
141 "output_model_dtype(uint8) granularity(layer, channel)");
143 arser.add_argument(tf_maxpool)
145 .default_value(false)
146 .help("Force MaxPool Op to have the same input/output quantparams. NOTE: This feature can "
147 "degrade accuracy of some models");
149 arser.add_argument(fake_quant)
151 .help("Convert a quantized model to a fake-quantized model. NOTE: This feature will "
152 "generate an fp32 model.");
154 arser.add_argument(rq)
156 .type(arser::DataType::STR_VEC)
157 .help("Requantize a quantized model. "
158 "Two arguments required: input_model_dtype(int8) "
159 "output_model_dtype(uint8)");
161 arser.add_argument(fq)
163 .type(arser::DataType::STR_VEC)
165 .help("Write quantization parameters to the specified tensor. "
166 "Three arguments required: tensor_name(string), "
167 "scale(float) zero_point(int)");
// NOTE(review): the first help literal below has no trailing space, so the adjacent literals
// concatenate to "...another tensor.Two arguments required..." in the rendered help text.
169 arser.add_argument(cq)
171 .type(arser::DataType::STR_VEC)
173 .help("Copy quantization parameter from a tensor to another tensor."
174 "Two arguments required: source_tensor_name(string), "
175 "destination_tensor_name(string)");
// Types of the quantized model's inputs/outputs, given as comma-separated values.
177 arser.add_argument("--input_type")
178 .help("Input type of quantized model (uint8, int16, int32, int64, float32, or bool). For "
180 "use comma-separated values. e.g., uint8,int16");
182 arser.add_argument("--output_type")
183 .help("Output type of quantized model (uint8, int16, int32, int64, float32, or bool). For "
185 "use comma-separated values. e.g., uint8,int16");
187 arser.add_argument(cfg).help("Path to the quantization configuration file");
// Positional arguments: input model path, then output model path.
189 arser.add_argument("input").help("Input circle model");
190 arser.add_argument("output").help("Output circle model");
192 arser.add_argument(gpd).nargs(0).required(false).default_value(false).help(
193 "This will turn on profiling data generation.");
// Parse the command line; the handler below prints the parser's error message to stderr.
197 arser.parse(argc, argv);
199 catch (const std::runtime_error &err)
201 std::cerr << err.what() << std::endl;
207 // only one of qdqw, qwmm, rq, fq, cq, fake_quant option can be used
// Count how many of the six mutually exclusive options were supplied.
208 int32_t opt_used = arser[qdqw] ? 1 : 0;
209 opt_used += arser[qwmm] ? 1 : 0;
210 opt_used += arser[rq] ? 1 : 0;
211 opt_used += arser[fq] ? 1 : 0;
212 opt_used += arser[cq] ? 1 : 0;
213 opt_used += arser[fake_quant] ? 1 : 0;
// NOTE(review): the condition on opt_used that guards this call is not visible here.
216 print_exclusive_options();
221 if (arser.get<bool>("--verbose"))
223 // The third parameter of setenv means REPLACE.
224 // If REPLACE is zero, it does not overwrite an existing value.
225 setenv("LUCI_LOG", "100", 0);
// --quantize_dequantize_weights: the value count is checked against 3, then the values are
// used in order as input model dtype, output model dtype and granularity.
230 auto values = arser.get<std::vector<std::string>>(qdqw);
231 if (values.size() != 3)
236 options->enable(Algorithms::QuantizeDequantizeWeights);
238 options->param(AlgorithmParameters::Quantize_input_model_dtype, values.at(0));
239 options->param(AlgorithmParameters::Quantize_output_model_dtype, values.at(1));
240 options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
// Per-layer parameters read from the --config file; a std::runtime_error raised while
// reading it is reported on stderr.
244 auto filename = arser.get<std::string>(cfg);
247 auto layer_params = read_layer_params(filename);
249 options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
251 catch (const std::runtime_error &e)
253 std::cerr << e.what() << '\n';
// --quantize_with_minmax: same three values as above (input dtype, output dtype, granularity).
261 auto values = arser.get<std::vector<std::string>>(qwmm);
262 if (values.size() != 3)
267 options->enable(Algorithms::QuantizeWithMinMax);
269 options->param(AlgorithmParameters::Quantize_input_model_dtype, values.at(0));
270 options->param(AlgorithmParameters::Quantize_output_model_dtype, values.at(1));
271 options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
// Optional settings forwarded only when the corresponding option was given.
273 if (arser["--input_type"])
274 options->param(AlgorithmParameters::Quantize_input_type,
275 arser.get<std::string>("--input_type"));
277 if (arser["--output_type"])
278 options->param(AlgorithmParameters::Quantize_output_type,
279 arser.get<std::string>("--output_type"));
// The TF-style maxpool parameter is passed as the string "True".
281 if (arser[tf_maxpool] and arser.get<bool>(tf_maxpool))
282 options->param(AlgorithmParameters::Quantize_TF_style_maxpool, "True");
// Per-layer parameters from the --config file, handled the same way as for the previous mode.
286 auto filename = arser.get<std::string>(cfg);
289 auto layer_params = read_layer_params(filename);
291 options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
293 catch (const std::runtime_error &e)
295 std::cerr << e.what() << '\n';
// --requantize: the value count is checked against 2 (input model dtype, output model dtype).
303 auto values = arser.get<std::vector<std::string>>(rq);
304 if (values.size() != 2)
309 options->enable(Algorithms::Requantize);
311 options->param(AlgorithmParameters::Quantize_input_model_dtype, values.at(0));
312 options->param(AlgorithmParameters::Quantize_output_model_dtype, values.at(1));
// --force_quantparam: read as a list of value groups; each group is checked for 3 entries and
// split into parallel lists (tensor name, scale, zero point).
317 auto values = arser.get<std::vector<std::vector<std::string>>>(fq);
319 std::vector<std::string> tensors;
320 std::vector<std::string> scales;
321 std::vector<std::string> zero_points;
323 for (auto const value : values)
325 if (value.size() != 3)
331 tensors.push_back(value[0]);
332 scales.push_back(value[1]);
333 zero_points.push_back(value[2]);
336 options->enable(Algorithms::ForceQuantParam);
338 options->params(AlgorithmParameters::Quantize_tensor_names, tensors);
339 options->params(AlgorithmParameters::Quantize_scales, scales);
340 options->params(AlgorithmParameters::Quantize_zero_points, zero_points);
// --copy_quantparam: read as a list of value groups; each group is checked for 2 entries and
// split into parallel lists (source tensor name, destination tensor name).
345 auto values = arser.get<std::vector<std::vector<std::string>>>(cq);
347 std::vector<std::string> src;
348 std::vector<std::string> dst;
350 for (auto const value : values)
352 if (value.size() != 2)
358 src.push_back(value[0]);
359 dst.push_back(value[1]);
362 options->enable(Algorithms::CopyQuantParam);
364 options->params(AlgorithmParameters::Quantize_src_tensor_names, src);
365 options->params(AlgorithmParameters::Quantize_dst_tensor_names, dst);
// --fake_quantize needs no arguments; it only enables the conversion algorithm.
368 if (arser[fake_quant])
369 options->enable(Algorithms::ConvertToFakeQuantizedModel);
371 std::string input_path = arser.get<std::string>("input");
372 std::string output_path = arser.get<std::string>("output");
// Turns on the ProfilingDataGen user setting.
// NOTE(review): presumably guarded by the --generate_profile_data flag on a line that is not
// visible in this excerpt - confirm.
375 settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
377 // Load model from the file
378 luci::ImporterEx importerex;
379 auto module = importerex.importVerifyModule(input_path);
// A null module means the import failed (the handling statement is not visible here).
380 if (module.get() == nullptr)
// Quantize and validate every graph contained in the module.
383 for (size_t idx = 0; idx < module->size(); ++idx)
385 auto graph = module->graph(idx);
387 // quantize the graph
388 quantizer.quantize(graph);
390 if (!luci::validate(graph))
392 std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
397 // Export to output Circle file
398 luci::CircleExporter exporter;
// The contract pairs the module with the output file path for the exporter.
400 luci::CircleFileExpContract contract(module.get(), output_path);
402 if (!exporter.invoke(&contract))
404 std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;