2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <foder/FileLoader.h>
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
25 #include <oops/InternalExn.h>
26 #include <arser/arser.h>
27 #include <vconone/vconone.h>
34 using OptionHook = std::function<int(const char **)>;
36 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
37 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
/**
 * @brief Tell the user that the quantization modes are mutually exclusive.
 *
 * Writes a header line followed by the three mode flags, one per line, to
 * standard output. Each line is terminated with std::endl (and so flushed).
 */
void print_exclusive_options(void)
{
  static const char *const kMessageLines[] = {
    "Use only one of the 3 options below.",
    " --quantize_dequantize_weights",
    " --quantize_with_minmax",
    " --requantize",
  };

  for (const char *line : kMessageLines)
    std::cout << line << std::endl;
}
47 void print_version(void)
49 std::cout << "circle-quantizer version " << vconone::get_string() << std::endl;
50 std::cout << vconone::get_copyright() << std::endl;
53 int entry(int argc, char **argv)
55 // Simple argument parser (based on map)
56 std::map<std::string, OptionHook> argparse;
57 luci::CircleOptimizer optimizer;
59 auto options = optimizer.options();
61 const std::string qdqw = "--quantize_dequantize_weights";
62 const std::string qwmm = "--quantize_with_minmax";
63 const std::string rq = "--requantize";
65 arser::Arser arser("circle-quantizer provides circle model quantization");
67 arser.add_argument("--version")
71 .help("Show version information and exit")
72 .exit_with(print_version);
74 arser.add_argument(qdqw)
76 .type(arser::DataType::STR_VEC)
78 .help("Quantize-dequantize weight values required action before quantization. "
79 "Three arguments required: input_dtype(float32) "
80 "output_dtype(uint8) granularity(layer, channel)");
82 arser.add_argument(qwmm)
84 .type(arser::DataType::STR_VEC)
86 .help("Quantize with min/max values. "
87 "Three arguments required: input_dtype(float32) "
88 "output_dtype(uint8) granularity(layer, channel)");
90 arser.add_argument(rq)
92 .type(arser::DataType::STR_VEC)
94 .help("Requantize a quantized model. "
95 "Two arguments required: input_dtype(int8) "
96 "output_dtype(uint8)");
98 arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
99 arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
103 arser.parse(argc, argv);
105 catch (const std::runtime_error &err)
107 std::cout << err.what() << std::endl;
114 if (arser[qwmm] || arser[rq])
116 print_exclusive_options();
119 auto values = arser.get<std::vector<std::string>>(qdqw);
120 if (values.size() != 3)
125 options->enable(Algorithms::QuantizeDequantizeWeights);
127 options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
128 options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
129 options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
134 if (arser[qdqw] || arser[rq])
136 print_exclusive_options();
139 auto values = arser.get<std::vector<std::string>>(qwmm);
140 if (values.size() != 3)
145 options->enable(Algorithms::QuantizeWithMinMax);
147 options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
148 options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
149 options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
154 if (arser[qwmm] || arser[qdqw])
156 print_exclusive_options();
159 auto values = arser.get<std::vector<std::string>>(rq);
160 if (values.size() != 2)
165 options->enable(Algorithms::Requantize);
167 options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
168 options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
171 std::string input_path = arser.get<std::string>("input");
172 std::string output_path = arser.get<std::string>("output");
174 // Load model from the file
175 foder::FileLoader file_loader{input_path};
176 std::vector<char> model_data = file_loader.load();
178 // Verify flatbuffers
179 flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
180 if (!circle::VerifyModelBuffer(verifier))
182 std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
186 const circle::Model *circle_model = circle::GetModel(model_data.data());
187 if (circle_model == nullptr)
189 std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
193 // Import from input Circle file
194 luci::Importer importer;
195 auto module = importer.importModule(circle_model);
197 for (size_t idx = 0; idx < module->size(); ++idx)
199 auto graph = module->graph(idx);
201 // quantize the graph
202 optimizer.quantize(graph);
204 if (!luci::validate(graph))
206 std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
211 // Export to output Circle file
212 luci::CircleExporter exporter;
214 luci::CircleFileExpContract contract(module.get(), output_path);
216 if (!exporter.invoke(&contract))
218 std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;