Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / core / Quantizer.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "Quantizer.h"
18 #include <luci/Service/Validate.h>
19
20 #include <iostream>
21
22 using namespace mpqsolver::core;
23 using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
24 using Algorithms = luci::CircleQuantizer::Options::Algorithm;
25
26 namespace
27 {
28
29 bool make_model_fake_quantized(luci::Module *module)
30 {
31   luci::CircleQuantizer quantizer;
32
33   auto options = quantizer.options();
34   options->enable(Algorithms::ConvertToFakeQuantizedModel);
35
36   for (size_t idx = 0; idx < module->size(); ++idx)
37   {
38     auto graph = module->graph(idx);
39     // quantize the graph
40     quantizer.quantize(graph);
41     if (!luci::validate(graph))
42     {
43       return false;
44     }
45   }
46
47   return true;
48 }
49
50 } // namespace
51
52 Quantizer::Quantizer(const std::string &input_dtype, const std::string &output_dtype)
53   : _input_dtype(input_dtype), _output_dtype(output_dtype)
54 {
55 }
56
57 void Quantizer::set_hook(const QuantizerHook *hook) { _hook = hook; }
58
59 /**
60  * @brief quantize recorded module (min/max initialized) with specified parameters
61  * returns true on success
62  */
63 bool Quantizer::quantize(luci::Module *module, const std::string &quant_dtype,
64                          LayerParams &layer_params)
65 {
66   if (!module)
67     return false;
68
69   static const std::string default_dtype = "float32";
70   static const std::string granularity_type = "channel";
71
72   luci::CircleQuantizer quantizer;
73
74   auto options = quantizer.options();
75   options->enable(Algorithms::QuantizeWithMinMax);
76
77   options->param(AlgorithmParameters::Quantize_input_model_dtype, default_dtype);
78   options->param(AlgorithmParameters::Quantize_output_model_dtype, quant_dtype);
79   options->param(AlgorithmParameters::Quantize_granularity, granularity_type);
80   options->param(AlgorithmParameters::Quantize_input_type, _input_dtype);
81   options->param(AlgorithmParameters::Quantize_output_type, _output_dtype);
82   options->param(AlgorithmParameters::Quantize_TF_style_maxpool, "False");
83
84   if (!layer_params.empty())
85   {
86     try
87     {
88       options->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
89     }
90     catch (const std::runtime_error &e)
91     {
92       std::cerr << e.what() << '\n';
93       return false;
94     }
95   }
96
97   for (size_t idx = 0; idx < module->size(); ++idx)
98   {
99     auto graph = module->graph(idx);
100     // quantize the graph
101     quantizer.quantize(graph);
102     if (!luci::validate(graph))
103     {
104       std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
105       return false;
106     }
107   }
108
109   if (_hook)
110   {
111     _hook->on_quantized(module);
112   }
113
114   return true;
115 }
116
117 /**
118  * @brief fake_quantize recorded module (min/max initialized) with specified parameters
119  * returns true on success
120  */
121 bool Quantizer::fake_quantize(luci::Module *module, const std::string &quant_dtype,
122                               LayerParams &layer_params)
123 {
124   if (!quantize(module, quant_dtype, layer_params))
125     return false;
126
127   if (!make_model_fake_quantized(module))
128     return false;
129
130   return true;
131 }