2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "BisectionSolver.h"
18 #include "DepthParameterizer.h"
19 #include "ErrorMetric.h"
20 #include "ErrorApproximator.h"
22 #include <luci/ImporterEx.h>
28 using namespace mpqsolver::bisection;
// Compares the summed approximate() error of nodes on the "input" side of the graph
// (depth < cut_depth) against that of the remaining "output" side, logs which precision
// (Q16 or Q8) will be placed at the input, and reports the outcome.
//
// nodes_depth : node -> depth mapping (iter.first is the node, iter.second its depth)
// cut_depth   : depth threshold separating the input side from the output side
// returns     : true iff the accumulated error at input is strictly larger than at output
//               (a tie returns false)
// NOTE(review): the branch structure routing cur_error into error_at_output is not visible
//               in this view — presumably the else of the depth test; confirm in full file.
33 bool error_at_input_is_larger_than_at_output(const NodeDepthType &nodes_depth, float cut_depth)
37 float error_at_input = 0;
38 float error_at_output = 0;
39 for (auto &iter : nodes_depth)
// approximate() (from ErrorApproximator.h) is used as a per-node error estimate
41 float cur_error = approximate(iter.first);
42 if (iter.second < cut_depth)
44 error_at_input += cur_error;
48 error_at_output += cur_error;
// larger error at input => prefer int16 for the front of the graph
52 if (error_at_input > error_at_output)
54 VERBOSE(l, 0) << "Q16 will be set at input due to ";
58 VERBOSE(l, 0) << "Q8 will be set at input due to ";
60 VERBOSE(l, 0) << error_at_input << " error at input vs ";
61 VERBOSE(l, 0) << error_at_output << " error at output." << std::endl;
63 return error_at_input > error_at_output;
// Imports the model stored at `path` via luci::ImporterEx::importVerifyModule.
// A null module from the importer is reported on std::cerr with the offending path.
// NOTE(review): what this function returns after printing the error is not visible in this
//               view; callers should be prepared to receive a null module — confirm.
66 std::unique_ptr<luci::Module> read_module(const std::string &path)
68 luci::ImporterEx importerex;
// importVerifyModule presumably also verifies the imported module — TODO confirm
69 auto module = importerex.importVerifyModule(path);
70 if (module.get() == nullptr)
72 std::cerr << "ERROR: Failed to load " << path << std::endl;
81 BisectionSolver::BisectionSolver(const std::string &input_data_path, float qerror_ratio,
82 const std::string &input_quantization,
83 const std::string &output_quantization)
84 : MPQSolver(input_data_path, qerror_ratio, input_quantization, output_quantization)
86 _quantizer = std::make_unique<Quantizer>(_input_quantization, _output_quantization);
89 float BisectionSolver::evaluate(const DatasetEvaluator &evaluator, const std::string &flt_path,
90 const std::string &def_quant, LayerParams &layers)
92 auto model = read_module(flt_path);
93 // get fake quantized model for evaluation
94 if (!_quantizer->fake_quantize(model.get(), def_quant, layers))
96 throw std::runtime_error("Failed to produce fake-quantized model.");
99 return evaluator.evaluate(model.get());
102 void BisectionSolver::algorithm(Algorithm algorithm) { _algorithm = algorithm; }
// Searches for a mixed uint8/int16 quantization of the model at module_path whose error stays
// below a target derived from the full-int16 and full-uint8 errors.
//
// Visible flow:
//  1. read the module and compute per-node depths plus the min/max depth via compute_depth;
//  2. fake-quantize the whole model as int16 and as uint8 (see evaluate()) and measure each
//     with a DatasetEvaluator built on _input_data_path and an MAEMetric;
//  3. target _qerror = int16 error + _qerror_ratio * |uint8 error - int16 error|;
//  4. if the full-uint8 error already meets the target, quantize the module as uint8;
//  5. otherwise bisect over depth: nodes on one side of cut_depth are forced to int16
//     (per-channel) while all other layers default to uint8, remembering the last
//     configuration whose error is below _qerror;
//  6. quantize the module as uint8 with the best per-layer int16 overrides found.
//
// Throws std::runtime_error when the full-int16 error exceeds the full-uint8 error, when the
// module does not contain exactly one graph, or (via evaluate) when fake quantization fails.
// NOTE(review): the end of this function lies outside the current view.
104 std::unique_ptr<luci::Module> BisectionSolver::run(const std::string &module_path)
108 auto module = read_module(module_path);
110 float min_depth = 0.f;
111 float max_depth = 0.f;
// node -> depth mapping filled by compute_depth, together with the depth range
112 NodeDepthType nodes_depth;
113 if (compute_depth(module.get(), nodes_depth, min_depth, max_depth) !=
114 ParameterizerResult::SUCCESS)
116 std::cerr << "ERROR: Invalid graph for bisectioning" << std::endl;
// the evaluator is built once against the float module and reused for every evaluate() call
120 std::unique_ptr<MAEMetric> metric = std::make_unique<MAEMetric>();
121 DatasetEvaluator evaluator(module.get(), _input_data_path, *metric.get());
// no per-layer overrides: the two baseline measurements use a single dtype for every layer
123 LayerParams layer_params;
125 evaluate(evaluator, module_path, "int16" /* default quant_dtype */, layer_params);
126 VERBOSE(l, 0) << "Full int16 model quantization error " << int16_qerror << std::endl;
129 evaluate(evaluator, module_path, "uint8" /* default quant_dtype */, layer_params);
130 VERBOSE(l, 0) << "Full uint8 model quantization error " << uint8_qerror << std::endl;
// the target computed below assumes int16 is at least as accurate as uint8
132 if (int16_qerror > uint8_qerror)
134 throw std::runtime_error("Q8 model's qerror is less than Q16 model's qerror.");
// target error: move from the int16 error toward the uint8 error by _qerror_ratio;
// after the check above the difference is non-negative, so fabs is only a safeguard
137 _qerror = int16_qerror + _qerror_ratio * std::fabs(uint8_qerror - int16_qerror);
138 VERBOSE(l, 0) << "Target quantization error " << _qerror << std::endl;
140 if (uint8_qerror <= _qerror)
142 // no need for bisectioning just return Q8 model
143 if (!_quantizer->quantize(module.get(), "uint8", layer_params))
145 std::cerr << "ERROR: Failed to quantize model" << std::endl;
// best_depth / best_params record the last cut whose error met the target.
// NOTE(review): if no cut ever meets it, best_params stays empty and the final quantize below
//               falls back to plain uint8 for every layer, i.e. the configuration the check
//               above found to exceed _qerror (assuming the uint8 branch returns early).
151 float best_depth = -1;
152 LayerParams best_params;
// only single-graph modules are supported
153 if (module->size() != 1)
155 throw std::runtime_error("Unsupported module");
157 auto graph = module->graph(0);
// candidate nodes: the set loco::active_nodes returns for the graph outputs
158 auto active_nodes = loco::active_nodes(loco::output_nodes(graph));
159 // input and output nodes are not valid for quantization, so let's remove them
160 for (auto node : loco::input_nodes(graph))
162 active_nodes.erase(node);
164 for (auto node : loco::output_nodes(graph))
166 active_nodes.erase(node);
169 // let's decide whether nodes at input are more susceptible to being quantized into Q16 than those at output
// true: nodes with depth <= cut_depth become int16; false: nodes with depth >= cut_depth do
171 bool int16_front = true;
// Auto: compare approximate errors on both sides of the depth midpoint
174 case Algorithm::Auto:
176 error_at_input_is_larger_than_at_output(nodes_depth, 0.5f * (max_depth + min_depth));
178 case Algorithm::ForceQ16Front:
181 case Algorithm::ForceQ16Back:
// midpoint of the current [min_depth, max_depth] interval, rounded down to an integer
188 int cut_depth = static_cast<int>(std::floor(0.5f * (min_depth + max_depth)));
// presumably ends the search once the cut stops moving (loop construct not visible here)
190 if (last_depth == cut_depth)
194 last_depth = cut_depth;
// per-iteration overrides; shadows the function-level layer_params declared above
196 LayerParams layer_params;
197 for (auto &node : active_nodes)
199 auto cur_node = loco::must_cast<luci::CircleNode *>(node);
200 auto iter = nodes_depth.find(cur_node);
201 if (iter == nodes_depth.end())
203 continue; // to filter out nodes like weights
206 float depth = iter->second;
// select nodes that fall on the int16 side of the cut
208 if ((depth <= cut_depth && int16_front) || (depth >= cut_depth && !int16_front))
// force this node to int16 with per-channel granularity
210 auto layer_param = std::make_shared<LayerParam>();
212 layer_param->name = cur_node->name();
213 layer_param->dtype = "int16";
214 layer_param->granularity = "channel";
217 layer_params.emplace_back(layer_param);
// NOTE(review): despite its name, cur_accuracy holds the error returned by evaluate()
//               and is compared against the target error (lower is better)
221 float cur_accuracy = evaluate(evaluator, module_path, "uint8", layer_params);
222 VERBOSE(l, 0) << cut_depth << " : " << cur_accuracy << std::endl;
224 if (cur_accuracy < _qerror)
// target met: shrink the int16 region and remember this configuration
226 int16_front ? (max_depth = cut_depth) : (min_depth = cut_depth);
227 best_params = layer_params;
228 best_depth = cut_depth;
// (presumably the target-missed branch) grow the int16 region
232 int16_front ? (min_depth = cut_depth) : (max_depth = cut_depth);
// final quantization: uint8 by default with the best int16 overrides found
236 VERBOSE(l, 0) << "Found the best configuration at " << best_depth << " depth." << std::endl;
237 if (!_quantizer->quantize(module.get(), "uint8", best_params))
239 std::cerr << "ERROR: Failed to quantize model" << std::endl;