2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/CircleQuantizer.h"
19 #include "luci/Pass/CopyQuantParamPass.h"
20 #include "luci/Pass/ForceQuantParamPass.h"
21 #include "luci/Pass/PropagateQParamForwardPass.h"
22 #include "luci/Pass/RequantizePass.h"
23 #include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
24 #include "luci/Pass/FoldDequantizePass.h"
25 #include "luci/Pass/RemoveRedundantDequantizePass.h"
26 #include "luci/Pass/QuantizePreCheckerPass.h"
27 #include "luci/Pass/QuantizeWithMinMaxPass.h"
28 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
29 #include "luci/Pass/QuantizeWeightsPass.h"
31 #include "luci/Pass/CircleShapeInferencePass.h"
32 #include "luci/Pass/CircleTypeInferencePass.h"
35 #include <logo/RemoveDeadNodeWithQueryPass.h>
37 #include "ProgressReporter.h"
38 #include "helpers/Strings.h"
40 #include "QuantizedModelVerifier.h"
42 #include <luci/IR/CircleNode.h>
43 #include <logo/Phase.h>
44 #include <pepper/csv2vec.h>
52 using LayerParam = luci::CircleQuantizer::Options::LayerParam;
54 // This function updates user-given input_type to match with the input signature of graph
55 // If user gives only one input_type, it will be expanded to the number of graph inputs
56 void canonicalize_input_type(loco::Graph *g, std::vector<loco::DataType> &input_type)
61 const auto inputs = g->inputs();
63 assert(inputs); // FIX_CALLER_UNLESS
65 // Check validity of the number of input dtype given by a user
66 if (input_type.size() != 1 and input_type.size() != inputs->size())
68 throw std::runtime_error(
69 "Invalid number of input dtype. The number of input dtype should be 1 or "
70 "the same as the number of graph inputs.");
73 // Handle the case when a user gives only one input dtype
74 if (input_type.size() == 1)
76 const auto user_given_dtype = input_type[0];
79 // Expand input dtype to the number of graph inputs
80 // Since quantizer can only quantize float32, user_given_dtype is set only for float32 inputs
81 auto input_nodes = loco::input_nodes(g);
82 for (uint32_t i = 0; i < input_nodes.size(); i++)
84 auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]);
86 if (input->dtype() == loco::DataType::FLOAT32)
87 input_type.push_back(user_given_dtype);
89 input_type.push_back(input->dtype());
93 // Finally, check validity of input_type
94 // input_type is valid if
95 // C1. for non-float32 model input, input_type == model's input dtype
97 // C2. for float32 model input, input_type == uint8, int16, or float32
98 auto input_nodes = loco::input_nodes(g);
99 for (uint32_t i = 0; i < input_nodes.size(); i++)
101 auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]);
102 assert(i == input->index()); // FIX_ME_UNLESS
104 if (input->dtype() != loco::DataType::FLOAT32)
107 if (input->dtype() != input_type[i])
108 throw std::runtime_error(
109 "Input dtype of " + input->name() +
110 " is invalid. It has to be the same with the model's input dtype.");
115 if (input_type[i] != loco::DataType::FLOAT32 and input_type[i] != loco::DataType::U8 and
116 input_type[i] != loco::DataType::S16)
118 throw std::runtime_error("Input dtype of " + input->name() +
119 " is invalid. For float32 input, the input dtype after "
120 "quantization must be one of uint8, int16, or float32.");
126 // This function updates user-given output_type to match with the output signature of graph
127 // If user gives only one output_type, it will be expanded to the number of graph outputs
128 // NOTE This function is almost same with canonicalize_input_type, but it is written as a
129 // separate function for more precise error messaging.
130 // TODO Find a way to reduce duplicate codes
131 void canonicalize_output_type(loco::Graph *g, std::vector<loco::DataType> &output_type)
136 const auto outputs = g->outputs();
138 assert(outputs); // FIX_CALLER_UNLESS
140 // Check validity of the number of output dtype given by a user
141 if (output_type.size() != 1 and output_type.size() != outputs->size())
143 throw std::runtime_error(
144 "Invalid number of output dtype. The number of output dtype should be 1 or "
145 "the same as the number of graph outputs.");
148 // Handle the case when a user gives only one output dtype
149 if (output_type.size() == 1)
151 const auto user_given_dtype = output_type[0];
154 // Expand output dtype to the number of graph outputs
155 // If dtype of graph output is float32, it will be replaced with user_given_dtype
156 // Otherwise, it will not change
157 auto output_nodes = loco::output_nodes(g);
158 for (uint32_t i = 0; i < output_nodes.size(); i++)
160 auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]);
162 if (output->dtype() == loco::DataType::FLOAT32)
163 output_type.push_back(user_given_dtype);
165 output_type.push_back(output->dtype());
169 // Finally, check validity of output_type
170 // output_type is valid if
171 // C1. for non-float32 model output, output_type == model's output dtype
173 // C2. for float32 model output, output_type == uint8, int16, or float32
174 auto output_nodes = loco::output_nodes(g);
175 for (uint32_t i = 0; i < output_nodes.size(); i++)
177 auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]);
178 assert(i == output->index()); // FIX_ME_UNLESS
180 if (output->dtype() != loco::DataType::FLOAT32)
183 if (output->dtype() != output_type[i])
184 throw std::runtime_error(
185 "Output dtype of " + output->name() +
186 " is invalid. It has to be the same with the model's output dtype.");
191 if (output_type[i] != loco::DataType::FLOAT32 and output_type[i] != loco::DataType::U8 and
192 output_type[i] != loco::DataType::S16)
194 throw std::runtime_error("Output dtype of " + output->name() +
195 " is invalid. For float32 output, the output dtype after "
196 "quantization must be one of uint8, int16, or float32.");
// Parses str as a T using stream extraction.
// Throws std::runtime_error when str does not start with a valid T, or when non-whitespace
// characters remain after the value (e.g. "1.5" for an integer, "0.1abc" for a float).
// Leading and trailing whitespace is accepted.
// Without these checks a malformed value silently became 0, which would have been forced
// into the model as a scale or zero point.
template <typename T> T lexical_cast(const std::string &str)
{
  std::istringstream ss(str);
  T data{};
  if (!(ss >> data))
    throw std::runtime_error("Failed to parse '" + str + "'");

  // Any further extractable character means the input had trailing garbage
  char trailing = 0;
  if (ss >> trailing)
    throw std::runtime_error("Failed to parse '" + str + "'");

  return data;
}

// Parses every string in sv as a T, preserving order. Throws like the scalar overload.
template <typename T> std::vector<T> lexical_cast(const std::vector<std::string> &sv)
{
  std::vector<T> result;
  result.reserve(sv.size());
  std::transform(sv.begin(), sv.end(), std::back_inserter(result),
                 [](const std::string &str) -> T { return lexical_cast<T>(str); });
  return result;
}
219 class QuantizeOptionsImpl final : public luci::CircleQuantizer::Options
222 void enable(Algorithm) final;
223 void param(AlgorithmParameters, const std::string &) final;
224 const std::string param(AlgorithmParameters) const final;
225 void params(AlgorithmParameters, std::vector<std::string> &) final;
226 std::vector<std::string> params(AlgorithmParameters) const final;
227 void layer_params(AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>> &) final;
228 std::vector<std::shared_ptr<LayerParam>> layer_params(AlgorithmParameters) const final;
229 bool query(Algorithm) final;
232 std::vector<Algorithm> _algorithms;
233 std::map<AlgorithmParameters, const std::string> _algorithm_params;
234 std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
235 std::map<AlgorithmParameters, std::vector<std::shared_ptr<LayerParam>>> _layer_params;
238 void QuantizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
240 void QuantizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
242 _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
245 const std::string QuantizeOptionsImpl::param(AlgorithmParameters param) const
247 auto param_str = _algorithm_params.find(param);
248 if (param_str != _algorithm_params.end())
250 return param_str->second;
254 return std::string();
258 void QuantizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
260 _multiple_params[param] = vec;
263 std::vector<std::string> QuantizeOptionsImpl::params(AlgorithmParameters param) const
265 auto param_vec = _multiple_params.find(param);
266 if (param_vec != _multiple_params.end())
268 return param_vec->second;
272 return std::vector<std::string>();
276 void QuantizeOptionsImpl::layer_params(AlgorithmParameters param,
277 std::vector<std::shared_ptr<LayerParam>> &vec)
279 _layer_params[param] = vec;
282 std::vector<std::shared_ptr<LayerParam>>
283 QuantizeOptionsImpl::layer_params(AlgorithmParameters param) const
285 auto param_vec = _layer_params.find(param);
286 if (param_vec != _layer_params.end())
288 return param_vec->second;
292 return std::vector<std::shared_ptr<LayerParam>>();
296 bool QuantizeOptionsImpl::query(Algorithm algo)
298 std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
299 if (it == _algorithms.end())
310 CircleQuantizer::Options *CircleQuantizer::options(void)
312 if (_options == nullptr)
314 _options = std::make_unique<QuantizeOptionsImpl>();
317 return _options.get();
320 void CircleQuantizer::quantize(loco::Graph *g) const
322 // Fake quantization of weights
323 if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
325 static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
326 static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
327 static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
329 auto input_model_dtype =
330 _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
331 auto output_model_dtype =
332 _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
333 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
334 auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
336 if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
337 throw std::runtime_error("Unsupported input type. List of supported input type: " +
338 to_string(fakeq_supported_input_model_dtype));
340 if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
341 throw std::runtime_error("Unsupported output type. List of supported output type: " +
342 to_string(fakeq_supported_output_model_dtype));
344 if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
345 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
346 to_string(fakeq_supported_granularity));
348 if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
349 str_to_dtype(output_model_dtype) != loco::DataType::U8)
350 throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
352 // Check dtype/granularity of layer params
353 for (auto layer_param : layer_params)
355 auto name = layer_param->name;
356 if (!in_array(to_lower_case(layer_param->dtype), fakeq_supported_output_model_dtype))
358 throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
359 to_string(fakeq_supported_output_model_dtype));
361 if (!in_array(to_lower_case(layer_param->granularity), fakeq_supported_granularity))
363 throw std::runtime_error(
364 "Unsupported granularity in " + name +
365 ". List of supported granularity: " + to_string(fakeq_supported_granularity));
369 // Clear existing quantparams before doing fake quantization
370 for (auto node : loco::active_nodes(loco::output_nodes(g)))
372 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
373 if (circle_node->quantparam() != nullptr)
374 circle_node->quantparam(nullptr);
377 auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsPass::Context>();
379 ctx->input_model_dtype = str_to_dtype(input_model_dtype);
380 ctx->output_model_dtype = str_to_dtype(output_model_dtype);
381 ctx->granularity = str_to_granularity(granularity);
383 for (auto layer_param : layer_params)
387 info.name = layer_param->name;
388 info.dtype = str_to_dtype(layer_param->dtype);
389 info.granularity = str_to_granularity(layer_param->granularity);
391 ctx->layers_info.emplace_back(info);
395 luci::QuantizeDequantizeWeightsPass fake_quantizer(std::move(ctx));
397 fake_quantizer.run(g);
400 // Actual quantization of weights, bias, and activation
401 if (_options->query(Options::Algorithm::QuantizeWithMinMax))
403 static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
404 static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
405 static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
406 static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "int32",
407 "int64", "float32", "bool"};
408 static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "int32",
409 "int64", "float32", "bool"};
411 auto input_model_dtype =
412 _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
413 auto output_model_dtype =
414 _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
415 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
416 auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
417 if (input_type.empty())
418 input_type = output_model_dtype;
419 auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
420 if (output_type.empty())
421 output_type = output_model_dtype;
423 auto input_type_vec = pepper::csv_to_vector<std::string>(input_type);
424 auto output_type_vec = pepper::csv_to_vector<std::string>(output_type);
426 bool TF_style_maxpool =
427 _options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True";
429 auto layer_params = _options->layer_params(Options::AlgorithmParameters::Quantize_layer_params);
431 if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
432 throw std::runtime_error("Unsupported input type. List of supported input types: " +
433 to_string(qwmm_supported_input_model_dtype));
435 if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
436 throw std::runtime_error("Unsupported output type. List of supported output types: " +
437 to_string(qwmm_supported_output_model_dtype));
439 if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
440 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
441 to_string(qwmm_supported_granularity));
443 for (const auto &dtype : input_type_vec)
445 if (!in_array(to_lower_case(dtype), qwmm_supported_input_type))
446 throw std::runtime_error("Unsupported input type. List of supported input types: " +
447 to_string(qwmm_supported_input_type));
450 for (const auto &dtype : output_type_vec)
452 if (!in_array(to_lower_case(dtype), qwmm_supported_output_type))
453 throw std::runtime_error("Unsupported output type. List of supported output types: " +
454 to_string(qwmm_supported_output_type));
457 if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
458 str_to_dtype(output_model_dtype) != loco::DataType::U8)
459 throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
461 // Check dtype/granularity of layer params
462 for (auto layer_param : layer_params)
464 auto name = layer_param->name;
465 if (!in_array(to_lower_case(layer_param->dtype), qwmm_supported_output_model_dtype))
467 throw std::runtime_error("Unsupported dtype in " + name + ". List of supported dtype: " +
468 to_string(qwmm_supported_output_model_dtype));
470 if (!in_array(to_lower_case(layer_param->granularity), qwmm_supported_granularity))
472 throw std::runtime_error(
473 "Unsupported granularity in " + name +
474 ". List of supported granularity: " + to_string(qwmm_supported_granularity));
478 auto input_types = str_vec_to_dtype_vec(input_type_vec);
479 auto output_types = str_vec_to_dtype_vec(output_type_vec);
481 // Canonicalize user-given input/output_type (match with # of inputs/outputs)
482 canonicalize_input_type(g, input_types);
483 canonicalize_output_type(g, output_types);
485 // Input model checker for quantization
486 luci::QuantizePreCheckerPass input_model_checker{};
487 input_model_checker.run(g);
489 auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
491 ctx->input_model_dtype = str_to_dtype(input_model_dtype);
492 ctx->output_model_dtype = str_to_dtype(output_model_dtype);
493 ctx->granularity = str_to_granularity(granularity);
494 ctx->input_types = input_types;
495 ctx->output_types = output_types;
496 ctx->TF_style_maxpool = TF_style_maxpool;
498 for (auto layer_param : layer_params)
502 info.name = layer_param->name;
503 info.dtype = str_to_dtype(layer_param->dtype);
504 info.granularity = str_to_granularity(layer_param->granularity);
506 ctx->layers_info.emplace_back(info);
510 luci::QuantizeWithMinMaxPass quantizer(std::move(ctx));
514 auto verify_ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
516 verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype);
517 verify_ctx->granularity = str_to_granularity(granularity);
518 verify_ctx->input_types = input_types;
519 verify_ctx->output_types = output_types;
520 verify_ctx->TF_style_maxpool = TF_style_maxpool;
522 for (auto layer_param : layer_params)
526 info.name = layer_param->name;
527 info.dtype = str_to_dtype(layer_param->dtype);
528 info.granularity = str_to_granularity(layer_param->granularity);
530 verify_ctx->layers_info.emplace_back(info);
534 // Verify the type/granularity of the quantized model
535 luci::QuantizedModelVerifier verifier(std::move(verify_ctx));
540 if (_options->query(Options::Algorithm::QuantizeWeights))
542 static const std::vector<std::string> qw_supported_input_model_dtype{"float32"};
543 static const std::vector<std::string> qw_supported_output_model_dtype{"int8", "int16"};
544 static const std::vector<std::string> qw_supported_granularity{"channel"};
546 auto input_model_dtype =
547 _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
548 auto output_model_dtype =
549 _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
550 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
552 if (!in_array(to_lower_case(input_model_dtype), qw_supported_input_model_dtype))
553 throw std::runtime_error("Unsupported input type. List of supported input type: " +
554 to_string(qw_supported_input_model_dtype));
556 if (!in_array(to_lower_case(output_model_dtype), qw_supported_output_model_dtype))
557 throw std::runtime_error("Unsupported output type. List of supported output type: " +
558 to_string(qw_supported_output_model_dtype));
560 if (!in_array(to_lower_case(granularity), qw_supported_granularity))
561 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
562 to_string(qw_supported_granularity));
563 auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
565 ctx->input_model_dtype = str_to_dtype(input_model_dtype);
566 ctx->output_model_dtype = str_to_dtype(output_model_dtype);
567 ctx->granularity = str_to_granularity(granularity);
569 luci::QuantizeWeightsPass weights_quantizer(std::move(ctx));
571 weights_quantizer.run(g);
575 if (_options->query(Options::Algorithm::Requantize))
577 static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
578 static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};
580 auto input_model_dtype =
581 _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
582 auto output_model_dtype =
583 _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
585 if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
586 throw std::runtime_error("Unsupported input type. List of supported input types: " +
587 to_string(rq_supported_input_model_dtype));
589 if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
590 throw std::runtime_error("Unsupported output type. List of supported output types: " +
591 to_string(rq_supported_output_model_dtype));
593 luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
594 str_to_dtype(output_model_dtype));
598 // Force to write quantparam to specified tensors
599 // NOTE Only per-tensor (not per-channel) qparam can be written
600 if (_options->query(Options::Algorithm::ForceQuantParam))
602 ForceQuantParamPass::TensorVector tensors =
603 _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
604 auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
605 auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);
607 // Cast scales/zero_points to proper types
608 ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
609 ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);
611 ForceQuantParamPass fq(tensors, scales, zero_points);
615 // Copy quantparam of a tensor to another tensor
616 if (_options->query(Options::Algorithm::CopyQuantParam))
618 CopyQuantParamPass::TensorVector src_tensors =
619 _options->params(Options::AlgorithmParameters::Quantize_src_tensor_names);
620 CopyQuantParamPass::TensorVector dst_tensors =
621 _options->params(Options::AlgorithmParameters::Quantize_dst_tensor_names);
623 CopyQuantParamPass cq(src_tensors, dst_tensors);
627 // Convert quantized model to fake-quantized model
628 if (_options->query(Options::Algorithm::ConvertToFakeQuantizedModel))
630 luci::ConvertToFakeQuantizedModelPass fake_quantizer;
631 fake_quantizer.run(g);
636 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
637 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
638 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
640 // Remove redundant Dequantize Ops generated during fake quantization
641 phase.emplace_back(std::make_unique<luci::RemoveRedundantDequantizePass>());
642 // Fold Dequantize Ops generated during fake quantization
643 phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
645 ProgressReporter prog(g, logo::PhaseStrategy::Restart);
646 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
647 phase_runner.attach(&prog);
648 phase_runner.run(phase);
653 // Do Shape/Type inference
654 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
655 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
657 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
658 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
659 phase_runner.attach(&prog);
660 phase_runner.run(phase);