2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
#include "Quantizer.h"

#include <iostream>
#include <stdexcept>
#include <string>

#include <luci/Service/Validate.h>

using namespace mpqsolver::core;
using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
using Algorithms = luci::CircleQuantizer::Options::Algorithm;
29 bool make_model_fake_quantized(luci::Module *module)
31 luci::CircleQuantizer quantizer;
33 auto options = quantizer.options();
34 options->enable(Algorithms::ConvertToFakeQuantizedModel);
36 for (size_t idx = 0; idx < module->size(); ++idx)
38 auto graph = module->graph(idx);
40 quantizer.quantize(graph);
41 if (!luci::validate(graph))
52 Quantizer::Quantizer(const std::string &input_dtype, const std::string &output_dtype)
53 : _input_dtype(input_dtype), _output_dtype(output_dtype)
57 void Quantizer::set_hook(const QuantizerHook *hook) { _hook = hook; }
60 * @brief quantize recorded module (min/max initialized) with specified parameters
61 * returns true on success
63 bool Quantizer::quantize(luci::Module *module, const std::string &quant_dtype,
64 LayerParams &layer_params)
69 static const std::string default_dtype = "float32";
70 static const std::string granularity_type = "channel";
72 luci::CircleQuantizer quantizer;
74 auto options = quantizer.options();
75 options->enable(Algorithms::QuantizeWithMinMax);
77 options->param(AlgorithmParameters::Quantize_input_model_dtype, default_dtype);
78 options->param(AlgorithmParameters::Quantize_output_model_dtype, quant_dtype);
79 options->param(AlgorithmParameters::Quantize_granularity, granularity_type);
80 options->param(AlgorithmParameters::Quantize_input_type, _input_dtype);
81 options->param(AlgorithmParameters::Quantize_output_type, _output_dtype);
82 options->param(AlgorithmParameters::Quantize_TF_style_maxpool, "False");
84 if (!layer_params.empty())
88 options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
90 catch (const std::runtime_error &e)
92 std::cerr << e.what() << '\n';
97 for (size_t idx = 0; idx < module->size(); ++idx)
99 auto graph = module->graph(idx);
100 // quantize the graph
101 quantizer.quantize(graph);
102 if (!luci::validate(graph))
104 std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
111 _hook->on_quantized(module);
118 * @brief fake_quantize recorded module (min/max initialized) with specified parameters
119 * returns true on success
121 bool Quantizer::fake_quantize(luci::Module *module, const std::string &quant_dtype,
122 LayerParams &layer_params)
124 if (!quantize(module, quant_dtype, layer_params))
127 if (!make_model_fake_quantized(module))