Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "luci/CircleOptimizer.h"

#include "luci/Pass/FuseBatchNormWithTConv.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
#include "luci/Pass/RequantizePass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
// TODO add more passes

#include "luci/Pass/ShapeInferencePass.h"
#include "luci/Pass/TypeInferencePass.h"

// logo passes
#include <logo/RemoveDeadNodeWithQueryPass.h>

#include "ProgressReporter.h"
#include "CircleOptimizerUtils.h"

#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>

#include <algorithm>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
43
44 namespace
45 {
46
47 using namespace luci;
48
49 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
50 {
51 public:
52   void enable(Algorithm) final;
53   void param(AlgorithmParameters, const std::string &) final;
54   const std::string param(AlgorithmParameters) const final;
55   bool query(Algorithm) final;
56
57 private:
58   std::vector<Algorithm> _algorithms;
59   std::map<AlgorithmParameters, const std::string> _algorithm_params;
60 };
61
62 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
63
64 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
65 {
66   _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
67 }
68
69 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
70 {
71   auto param_str = _algorithm_params.find(param);
72   if (param_str != _algorithm_params.end())
73   {
74     return param_str->second;
75   }
76   else
77   {
78     return std::string();
79   }
80 }
81
82 bool OptimizeOptionsImpl::query(Algorithm algo)
83 {
84   std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
85   if (it == _algorithms.end())
86     return false;
87
88   return true;
89 }
90
91 } // namespace
92
93 namespace luci
94 {
95
96 CircleOptimizer::Options *CircleOptimizer::options(void)
97 {
98   if (_options == nullptr)
99   {
100     _options = std::make_unique<OptimizeOptionsImpl>();
101   }
102
103   return _options.get();
104 }
105
106 void CircleOptimizer::optimize(loco::Graph *g) const
107 {
108   logo::Phase phase;
109
110   /* TRANSFORM DECLARATION BEGIN */
111   if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
112   {
113     phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
114   }
115   if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
116   {
117     phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
118   }
119   if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
120   {
121     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
122   }
123   if (_options->query(Options::Algorithm::FuseInstanceNorm))
124   {
125     phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
126   }
127   if (_options->query(Options::Algorithm::FuseBCQ))
128   {
129     phase.emplace_back(std::make_unique<FuseBCQPass>());
130   }
131   if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
132   {
133     phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
134   }
135
136   // Shape inference is needed for added nodes doing above transformations
137   phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
138   phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
139   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
140   /* TRANSFORM DECLARATION END */
141
142   ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
143   logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
144   phase_runner.attach(&prog);
145   phase_runner.run(phase);
146 }
147
148 void CircleOptimizer::quantize(loco::Graph *g) const
149 {
150   // Fake quantization of weights
151   if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
152   {
153     static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
154     static const std::vector<std::string> fakeq_supported_output_dtype{"uint8"};
155     static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
156
157     auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
158     auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
159     auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
160
161     if (!in_array(to_lower_case(input_dtype), fakeq_supported_input_dtype))
162       throw std::runtime_error("Unsupported input type. List of supported input type: " +
163                                to_string(fakeq_supported_input_dtype));
164
165     if (!in_array(to_lower_case(output_dtype), fakeq_supported_output_dtype))
166       throw std::runtime_error("Unsupported output type. List of supported output type: " +
167                                to_string(fakeq_supported_output_dtype));
168
169     if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
170       throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
171                                to_string(fakeq_supported_granularity));
172
173     // Clear existing quantparams before doing fake quantization
174     for (auto node : loco::active_nodes(loco::output_nodes(g)))
175     {
176       auto circle_node = loco::must_cast<luci::CircleNode *>(node);
177       if (circle_node->quantparam() != nullptr)
178         circle_node->quantparam(nullptr);
179     }
180
181     luci::QuantizeDequantizeWeightsPass fake_quantizer(
182         str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
183     fake_quantizer.run(g);
184   }
185
186   // Actual quantization of weights, bias, and activation
187   if (_options->query(Options::Algorithm::QuantizeWithMinMax))
188   {
189     static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
190     static const std::vector<std::string> qwmm_supported_output_dtype{"uint8"};
191     static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
192
193     auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
194     auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
195     auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
196
197     if (!in_array(to_lower_case(input_dtype), qwmm_supported_input_dtype))
198       throw std::runtime_error("Unsupported input type. List of supported input types: " +
199                                to_string(qwmm_supported_input_dtype));
200
201     if (!in_array(to_lower_case(output_dtype), qwmm_supported_output_dtype))
202       throw std::runtime_error("Unsupported output type. List of supported output types: " +
203                                to_string(qwmm_supported_output_dtype));
204
205     if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
206       throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
207                                to_string(qwmm_supported_granularity));
208
209     luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
210                                            str_to_granularity(granularity));
211     quantizer.run(g);
212   }
213
214   // Requantize
215   if (_options->query(Options::Algorithm::Requantize))
216   {
217     static const std::vector<std::string> rq_supported_input_dtype{"int8"};
218     static const std::vector<std::string> rq_supported_output_dtype{"uint8"};
219
220     auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
221     auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
222
223     if (!in_array(to_lower_case(input_dtype), rq_supported_input_dtype))
224       throw std::runtime_error("Unsupported input type. List of supported input types: " +
225                                to_string(rq_supported_input_dtype));
226
227     if (!in_array(to_lower_case(output_dtype), rq_supported_output_dtype))
228       throw std::runtime_error("Unsupported output type. List of supported output types: " +
229                                to_string(rq_supported_output_dtype));
230
231     luci::RequantizePass requantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype));
232     requantizer.run(g);
233   }
234
235   logo::Phase phase;
236
237   // Do Shape/Type inference
238   phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
239   phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
240
241   ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
242   logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
243   phase_runner.attach(&prog);
244   phase_runner.run(phase);
245 }
246
247 } // namespace luci