34f647301969fd72822fe4165e43bdb6212e99c1
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/CircleOptimizer.h"
18
19 #include "luci/Pass/FoldDequantizePass.h"
20 #include "luci/Pass/FuseActivationFunctionPass.h"
21 #include "luci/Pass/FuseAddWithTConvPass.h"
22 #include "luci/Pass/FuseBatchNormWithTConv.h"
23 #include "luci/Pass/FuseBCQPass.h"
24 #include "luci/Pass/FuseInstanceNormPass.h"
25 #include "luci/Pass/FusePreActivationBatchNormPass.h"
26 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
27 #include "luci/Pass/ResolveCustomOpAddPass.h"
28 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
29 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
30 #include "luci/Pass/RequantizePass.h"
31 #include "luci/Pass/QuantizeWithMinMaxPass.h"
32 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
33 #include "luci/Pass/SparsifyTensorPass.h"
34 // TODO add more passes
35
36 #include "luci/Pass/ShapeInferencePass.h"
37 #include "luci/Pass/TypeInferencePass.h"
38
39 // logo passes
40 #include <logo/RemoveDeadNodeWithQueryPass.h>
41
42 #include "ProgressReporter.h"
43 #include "CircleOptimizerUtils.h"
44
45 #include <luci/IR/CircleNodes.h>
46 #include <logo/Phase.h>
47
#include <algorithm> // std::find
#include <cassert>   // assert
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
50
51 namespace
52 {
53
/// @brief Parse a comma-delimited list of integers (e.g. "1,4,2") into a vector.
///
/// An empty string yields an empty vector; parsing stops at the first token
/// that is not an integer.
std::vector<int> parseIntFromCommadelimitedStr(std::string str)
{
  std::vector<int> ret;
  std::istringstream is(str);
  // NOTE operator>> never extracts the ',' delimiter, so the former
  // `assert(i != ',')` could only compare the parsed VALUE against 44 (the
  // code point of ',') and aborted debug builds on the legal input "44".
  // The assert is removed; `int` also replaces `uint32_t` to match the
  // element type of the returned vector.
  for (int i; is >> i;)
  {
    ret.push_back(i);
    if (is.peek() == ',')
      is.ignore(); // skip the delimiter before the next number
  }
  return ret;
}
67
68 using namespace luci;
69
// Concrete implementation of CircleOptimizer::Options backed by simple
// in-memory containers: enabled algorithms in a vector, parameter strings in
// a map keyed by AlgorithmParameters.
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
{
public:
  // Marks the given algorithm as enabled (duplicates are not filtered).
  void enable(Algorithm) final;
  // Stores the string value for the given parameter.
  void param(AlgorithmParameters, const std::string &) final;
  // Returns the stored value, or an empty string when the parameter is unset.
  const std::string param(AlgorithmParameters) const final;
  // Tells whether the given algorithm was previously enable()d.
  bool query(Algorithm) final;

private:
  std::vector<Algorithm> _algorithms;
  std::map<AlgorithmParameters, const std::string> _algorithm_params;
};
82
83 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
84
85 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
86 {
87   _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
88 }
89
90 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
91 {
92   auto param_str = _algorithm_params.find(param);
93   if (param_str != _algorithm_params.end())
94   {
95     return param_str->second;
96   }
97   else
98   {
99     return std::string();
100   }
101 }
102
103 bool OptimizeOptionsImpl::query(Algorithm algo)
104 {
105   std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
106   if (it == _algorithms.end())
107     return false;
108
109   return true;
110 }
111
112 } // namespace
113
114 namespace luci
115 {
116
117 CircleOptimizer::Options *CircleOptimizer::options(void)
118 {
119   if (_options == nullptr)
120   {
121     _options = std::make_unique<OptimizeOptionsImpl>();
122   }
123
124   return _options.get();
125 }
126
127 void CircleOptimizer::optimize(loco::Graph *g) const
128 {
129   logo::Phase phase;
130
131   /* TRANSFORM DECLARATION BEGIN */
132   if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
133   {
134     phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
135   }
136   if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
137   {
138     phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
139   }
140   if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
141   {
142     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
143   }
144   if (_options->query(Options::Algorithm::FuseInstanceNorm))
145   {
146     phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
147   }
148   if (_options->query(Options::Algorithm::FuseBCQ))
149   {
150     phase.emplace_back(std::make_unique<FuseBCQPass>());
151   }
152   if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
153   {
154     phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
155   }
156   if (_options->query(Options::Algorithm::FuseAddWithTConv))
157   {
158     phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
159   }
160   if (_options->query(Options::Algorithm::FuseActivationFunction))
161   {
162     phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
163   }
164   if (_options->query(Options::Algorithm::FoldDequantize))
165   {
166     phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
167   }
168   if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
169   {
170     phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
171   }
172   if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
173   {
174     phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
175   }
176
177   // Shape inference is needed for added nodes doing above transformations
178   phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
179   phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
180   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
181   /* TRANSFORM DECLARATION END */
182
183   ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
184   logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
185   phase_runner.attach(&prog);
186   phase_runner.run(phase);
187 }
188
// Quantizes the graph in place according to the Quantize_* algorithm options.
// Up to three independent steps may run, in this order:
//   1. QuantizeDequantizeWeights : fake quantization of weights
//   2. QuantizeWithMinMax        : actual quantization of weights/bias/activation
//   3. Requantize                : int8 -> uint8 requantization
// A Shape/Type inference phase is always run afterwards.
//
// @throws std::runtime_error on an unsupported input/output dtype or
//         granularity, or when layer-wise granularity is combined with a
//         non-uint8 output dtype.
void CircleOptimizer::quantize(loco::Graph *g) const
{
  // Fake quantization of weights
  if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
  {
    // Option values accepted by this step (compared case-insensitively below)
    static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
    static const std::vector<std::string> fakeq_supported_output_dtype{"uint8", "int16"};
    static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};

    auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
    auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
    auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);

    if (!in_array(to_lower_case(input_dtype), fakeq_supported_input_dtype))
      throw std::runtime_error("Unsupported input type. List of supported input type: " +
                               to_string(fakeq_supported_input_dtype));

    if (!in_array(to_lower_case(output_dtype), fakeq_supported_output_dtype))
      throw std::runtime_error("Unsupported output type. List of supported output type: " +
                               to_string(fakeq_supported_output_dtype));

    if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
      throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
                               to_string(fakeq_supported_granularity));

    // Layer-wise granularity is only defined for uint8 output
    if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
        str_to_dtype(output_dtype) != loco::DataType::U8)
      throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");

    // Clear existing quantparams before doing fake quantization
    for (auto node : loco::active_nodes(loco::output_nodes(g)))
    {
      auto circle_node = loco::must_cast<luci::CircleNode *>(node);
      if (circle_node->quantparam() != nullptr)
        circle_node->quantparam(nullptr);
    }

    luci::QuantizeDequantizeWeightsPass fake_quantizer(
        str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
    fake_quantizer.run(g);
  }

  // Actual quantization of weights, bias, and activation
  if (_options->query(Options::Algorithm::QuantizeWithMinMax))
  {
    // Option values accepted by this step (same set as fake quantization)
    static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
    static const std::vector<std::string> qwmm_supported_output_dtype{"uint8", "int16"};
    static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};

    auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
    auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
    auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);

    if (!in_array(to_lower_case(input_dtype), qwmm_supported_input_dtype))
      throw std::runtime_error("Unsupported input type. List of supported input types: " +
                               to_string(qwmm_supported_input_dtype));

    if (!in_array(to_lower_case(output_dtype), qwmm_supported_output_dtype))
      throw std::runtime_error("Unsupported output type. List of supported output types: " +
                               to_string(qwmm_supported_output_dtype));

    if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
      throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
                               to_string(qwmm_supported_granularity));

    // Layer-wise granularity is only defined for uint8 output
    if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
        str_to_dtype(output_dtype) != loco::DataType::U8)
      throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");

    luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
                                           str_to_granularity(granularity));
    quantizer.run(g);
  }

  // Requantize (only int8 -> uint8 is supported, per the lists below)
  if (_options->query(Options::Algorithm::Requantize))
  {
    static const std::vector<std::string> rq_supported_input_dtype{"int8"};
    static const std::vector<std::string> rq_supported_output_dtype{"uint8"};

    auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
    auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);

    if (!in_array(to_lower_case(input_dtype), rq_supported_input_dtype))
      throw std::runtime_error("Unsupported input type. List of supported input types: " +
                               to_string(rq_supported_input_dtype));

    if (!in_array(to_lower_case(output_dtype), rq_supported_output_dtype))
      throw std::runtime_error("Unsupported output type. List of supported output types: " +
                               to_string(rq_supported_output_dtype));

    luci::RequantizePass requantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype));
    requantizer.run(g);
  }

  logo::Phase phase;

  // Do Shape/Type inference
  phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
  phase.emplace_back(std::make_unique<luci::TypeInferencePass>());

  ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
  logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
  phase_runner.attach(&prog);
  phase_runner.run(phase);
}
295
296 void CircleOptimizer::sparsify(loco::Graph *g) const
297 {
298   if (_options->query(Options::Algorithm::SparsifyTensorPass))
299   {
300     std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
301     std::string str_tarversal_order =
302         _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
303     std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
304     std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
305     std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
306
307     // traversal order
308     std::vector<int32_t> traversal_order = parseIntFromCommadelimitedStr(str_tarversal_order);
309     // format
310     std::vector<DimensionType> format;
311     std::istringstream is(str_format);
312     for (char c; is >> c;)
313     {
314       assert(c != ',');
315       if (c == 'd')
316         format.push_back(DimensionType::DENSE);
317       else if (c == 's')
318         format.push_back(DimensionType::SPARSE_CSR);
319       if (is.peek() == ',')
320         is.ignore();
321     }
322     // block size
323     std::vector<int32_t> block_size = parseIntFromCommadelimitedStr(str_block_size);
324     // block map
325     std::vector<int32_t> block_map = parseIntFromCommadelimitedStr(str_block_map);
326
327     luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
328                                         block_map};
329     sparsifier.run(g);
330   }
331 }
332
333 } // namespace luci