2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <arser/arser.h>
18 #include <vconone/vconone.h>
19 #include <luci/CircleExporter.h>
20 #include <luci/CircleFileExpContract.h>
23 #include "bisection/BisectionSolver.h"
// Writes two lines to std::cout: the tool name followed by the version string
// returned by vconone::get_string(), then the text returned by
// vconone::get_copyright(). Takes no arguments and returns nothing.
// It is handed to arser::Helper::add_version() in entry() as the version callback.
29 void print_version(void)
31 std::cout << "circle-mpqsolver version " << vconone::get_string() << std::endl;
32 std::cout << vconone::get_copyright() << std::endl;
// Command-line entry point of circle-mpqsolver.
//
// Declares the command-line options with arser, parses argv, reads the option
// values, range-checks --qerror_ratio and, in the bisection branch, runs
// mpqsolver::bisection::BisectionSolver on the input model and exports the result
// to the output path through luci::CircleExporter. The elapsed time is reported
// through VERBOSE at the end.
//
// argc, argv: forwarded unchanged to arser.parse().
// NOTE(review): no return statement is visible in this listing, so the exit-code
// contract is deliberately not documented here — confirm against the full file.
35 int entry(int argc, char **argv)
// Name of the bisection option, kept in one constant because it is used three
// times below: to register the option, to test for it and to read its value.
39 const std::string bisection_str = "--bisection";
// Parser object; the string literal is the program description given to arser.
41 arser::Arser arser("circle-mpqsolver provides light-weight methods for finding a high-quality "
42 "mixed-precision model within a reasonable time.");
// Register the version callback (print_version) and the verbose flag through
// arser's helper functions.
44 arser::Helper::add_version(arser, print_version);
45 arser::Helper::add_verbose(arser);
// Test data path (required) and test data format (not required).
47 arser.add_argument("--data").required(true).help("Path to the test data");
48 arser.add_argument("--data_format").required(false).help("Test data format (default: h5)");
// Float-typed option; its value is range-checked against [0, 1] further down.
50 arser.add_argument("--qerror_ratio")
51 .type(arser::DataType::FLOAT)
53 .help("quantization error ratio ([0, 1])");
// String-typed option for the bisection method; its value is compared against
// "true" and "false" below to pick the solver algorithm.
55 arser.add_argument(bisection_str)
57 .type(arser::DataType::STR)
58 .help("Single optional argument for bisection method. "
59 "Whether input node should be quantized to Q16: 'auto', 'true', 'false'.");
// Path of the model to be processed.
61 arser.add_argument("--input_model")
63 .help("Input float model with min max initialized");
// Input and output data types of the quantized model; both default to "uint8".
65 arser.add_argument("--input_dtype")
66 .type(arser::DataType::STR)
67 .default_value("uint8")
68 .help("Data type of quantized model's inputs (default: uint8)");
70 arser.add_argument("--output_dtype")
71 .type(arser::DataType::STR)
72 .default_value("uint8")
73 .help("Data type of quantized model's outputs (default: uint8)");
// Destination path of the result (required).
75 arser.add_argument("--output_model").required(true).help("Output quantized model");
// Parse the command line; a std::runtime_error raised while parsing is caught
// and its message is written to std::cerr.
79 arser.parse(argc, argv);
81 catch (const std::runtime_error &err)
83 std::cerr << err.what() << std::endl;
// When --verbose is set, export LUCI_LOG=100 to the environment. A LUCI_LOG
// value that the user has already set is left as it is (see the note below).
88 if (arser.get<bool>("--verbose"))
90 // The third parameter of setenv means REPLACE.
91 // If REPLACE is zero, it does not overwrite an existing value.
92 setenv("LUCI_LOG", "100", 0);
// Read the parsed option values into locals.
95 auto data_path = arser.get<std::string>("--data");
96 auto input_model_path = arser.get<std::string>("--input_model");
97 auto output_model_path = arser.get<std::string>("--output_model");
98 auto input_dtype = arser.get<std::string>("--input_dtype");
99 auto output_dtype = arser.get<std::string>("--output_dtype");
// Reject a ratio outside the closed interval [0, 1] with an error on std::cerr.
101 float qerror_ratio = arser.get<float>("--qerror_ratio");
102 if (qerror_ratio < 0.0 || qerror_ratio > 1.f)
104 std::cerr << "ERROR: quantization ratio must be in [0, 1]" << std::endl;
// Time stamp taken before solving; used for the elapsed-time report below.
107 auto start = std::chrono::high_resolution_clock::now();
// Bisection branch, entered when arser[bisection_str] evaluates to true
// (presumably: the option was supplied on the command line — confirm with arser).
109 if (arser[bisection_str])
112 using namespace mpqsolver::bisection;
// Solver constructed from the data path, the error ratio and both I/O dtypes.
114 BisectionSolver solver(data_path, qerror_ratio, input_dtype, output_dtype);
// Translate the option value into a BisectionSolver::Algorithm:
// "true" -> ForceQ16Front, "false" -> ForceQ16Back.
// NOTE(review): the condition guarding the Auto case is not visible in this
// listing; presumably it tests for "auto" — confirm against the full file.
116 auto value = arser.get<std::string>(bisection_str);
119 solver.algorithm(BisectionSolver::Algorithm::Auto);
121 else if (value == "true")
123 solver.algorithm(BisectionSolver::Algorithm::ForceQ16Front);
125 else if (value == "false")
127 solver.algorithm(BisectionSolver::Algorithm::ForceQ16Back);
// NOTE(review): the message below misspells "algorithm" ("algortithm") and streams
// input_model_path directly after the text with no separator; the rejected option
// value looks like the more relevant thing to print here. Left untouched because
// it is runtime output — fix in a code change.
131 std::cerr << "ERROR: Unrecognized option for bisection algortithm" << input_model_path
// Run the solver on the input model; a nullptr result is reported as a failure.
137 auto optimized = solver.run(input_model_path);
138 if (optimized == nullptr)
// NOTE(review): the path is appended to the message without a separating space.
140 std::cerr << "ERROR: Failed to build mixed precision model" << input_model_path << std::endl;
// Export: wrap the solver result and output_model_path in a CircleFileExpContract
// and pass it to the exporter; a false return from invoke() is reported as an error.
146 luci::CircleExporter exporter;
147 luci::CircleFileExpContract contract(optimized.get(), output_model_path);
148 if (!exporter.invoke(&contract))
// NOTE(review): this prints input_model_path although the failing operation targets
// output_model_path, and again no separator precedes the path — confirm intent.
150 std::cerr << "ERROR: Failed to export mixed precision model" << input_model_path
// Presumably the fallback for when the bisection branch is not taken (the
// enclosing else is not visible in this listing).
158 std::cerr << "ERROR: Unrecognized solver" << std::endl;
// Time elapsed since `start`, truncated to whole seconds by duration_cast, then
// divided by 60.f and printed as minutes with a stream precision of 5.
// NOTE(review): the logger `l` used by VERBOSE is not declared in the visible lines.
162 auto duration = std::chrono::duration_cast<std::chrono::seconds>(
163 std::chrono::high_resolution_clock::now() - start);
164 VERBOSE(l, 0) << "Elapsed Time: " << std::setprecision(5) << duration.count() / 60.f
165 << " minutes." << std::endl;