Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/CircleOptimizer.h"
18
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/ExpandBroadcastConstPass.h"
21 #include "luci/Pass/FoldAddV2Pass.h"
22 #include "luci/Pass/FoldCastPass.h"
23 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
24 #include "luci/Pass/FoldDequantizePass.h"
25 #include "luci/Pass/FoldSparseToDensePass.h"
26 #include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
27 #include "luci/Pass/ForceQuantParamPass.h"
28 #include "luci/Pass/FuseActivationFunctionPass.h"
29 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
30 #include "luci/Pass/FuseAddWithTConvPass.h"
31 #include "luci/Pass/FuseBatchNormWithConvPass.h"
32 #include "luci/Pass/FuseBatchNormWithDwConvPass.h"
33 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
34 #include "luci/Pass/FuseBCQPass.h"
35 #include "luci/Pass/FuseInstanceNormPass.h"
36 #include "luci/Pass/FuseMeanWithMeanPass.h"
37 #include "luci/Pass/FusePreActivationBatchNormPass.h"
38 #include "luci/Pass/FuseTransposeWithMeanPass.h"
39 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
40 #include "luci/Pass/PropagateQuantParamPass.h"
41 #include "luci/Pass/RemoveFakeQuantPass.h"
42 #include "luci/Pass/RemoveQuantDequantSeqPass.h"
43 #include "luci/Pass/RemoveRedundantReshapePass.h"
44 #include "luci/Pass/RemoveRedundantTransposePass.h"
45 #include "luci/Pass/RemoveUnnecessaryReshapePass.h"
46 #include "luci/Pass/RemoveUnnecessarySlicePass.h"
47 #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
48 #include "luci/Pass/RemoveUnnecessarySplitPass.h"
49 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
50 #include "luci/Pass/ReplaceSubWithAddPass.h"
51 #include "luci/Pass/ResolveCustomOpAddPass.h"
52 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
53 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
54 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
55 #include "luci/Pass/RequantizePass.h"
56 #include "luci/Pass/QuantizeWithMinMaxPass.h"
57 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
58 #include "luci/Pass/SparsifyTensorPass.h"
59 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
60 #include "luci/Pass/SubstitutePackToReshapePass.h"
61 #include "luci/Pass/SubstitutePadV2ToPadPass.h"
62 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
63 #include "luci/Pass/SubstituteSqueezeToReshapePass.h"
64 #include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
65 #include "luci/Pass/SubstituteTransposeToReshapePass.h"
66 #include "luci/Pass/TransformMinMaxToRelu6Pass.h"
67 #include "luci/Pass/TransformMinReluToRelu6Pass.h"
68 // TODO add more passes
69
70 #include "luci/Pass/CircleShapeInferencePass.h"
71 #include "luci/Pass/CircleTypeInferencePass.h"
72
73 // logo passes
74 #include <logo/RemoveDeadNodeWithQueryPass.h>
75
76 #include "ModulePhase.h"
77 #include "ProgressReporter.h"
78 #include "helpers/Strings.h"
79
80 #include "QuantizedModelVerifier.h"
81
82 #include <luci/IR/CircleNodes.h>
83 #include <logo/Phase.h>
84 #include <pepper/csv2vec.h>
85
86 #include <memory>
87 #include <sstream>
88
89 namespace
90 {
91
92 using namespace luci;
93
/// @brief Parse a single value of type T from a string via stream extraction.
/// @return The parsed value; on a failed extraction, a value-initialized T
///         (0 for arithmetic types) rather than an uninitialized one.
template <typename T> T lexical_cast(const std::string &str)
{
  std::istringstream ss;
  ss.str(str);
  // Value-initialize: if extraction fails for a non-arithmetic T, the original
  // `T data;` would be returned uninitialized (undefined behavior to read).
  T data{};
  ss >> data;
  return data;
}
102
/// @brief Parse each string in `sv` into a T, preserving order.
template <typename T> std::vector<T> lexical_cast(std::vector<std::string> &sv)
{
  std::vector<T> casted;
  casted.reserve(sv.size());
  for (const auto &s : sv)
    casted.push_back(lexical_cast<T>(s));
  return casted;
}
110
// Concrete implementation of CircleOptimizer::Options: records which algorithms
// are enabled and the parameter values supplied for them.
class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
{
public:
  void enable(Algorithm) final;
  void param(AlgorithmParameters, const std::string &) final;
  const std::string param(AlgorithmParameters) const final;
  void params(AlgorithmParameters, std::vector<std::string> &) final;
  std::vector<std::string> params(AlgorithmParameters) const final;
  bool query(Algorithm) final;

private:
  // Enabled algorithms; duplicates are possible, query() only checks membership
  std::vector<Algorithm> _algorithms;
  // Single-valued parameters; populated via map::insert, so the first value set
  // for a key wins (later calls with the same key are ignored)
  std::map<AlgorithmParameters, const std::string> _algorithm_params;
  // Multi-valued parameters; populated via operator[], so the last value wins
  std::map<AlgorithmParameters, std::vector<std::string>> _multiple_params;
};
126
127 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
128
129 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
130 {
131   _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
132 }
133
134 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
135 {
136   auto param_str = _algorithm_params.find(param);
137   if (param_str != _algorithm_params.end())
138   {
139     return param_str->second;
140   }
141   else
142   {
143     return std::string();
144   }
145 }
146
147 void OptimizeOptionsImpl::params(AlgorithmParameters param, std::vector<std::string> &vec)
148 {
149   _multiple_params[param] = vec;
150 }
151
152 std::vector<std::string> OptimizeOptionsImpl::params(AlgorithmParameters param) const
153 {
154   auto param_vec = _multiple_params.find(param);
155   if (param_vec != _multiple_params.end())
156   {
157     return param_vec->second;
158   }
159   else
160   {
161     return std::vector<std::string>();
162   }
163 }
164
165 bool OptimizeOptionsImpl::query(Algorithm algo)
166 {
167   std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
168   if (it == _algorithms.end())
169     return false;
170
171   return true;
172 }
173
174 void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output)
175 {
176   logo::Phase phase;
177
178   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
179   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
180   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
181
182   phase.emplace_back(
183     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
184
185   ProgressReporter prog(g, logo::PhaseStrategy::Restart);
186   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
187   phase_runner.attach(&prog);
188   phase_runner.run(phase);
189 }
190
191 } // namespace
192
193 namespace luci
194 {
195
196 CircleOptimizer::Options *CircleOptimizer::options(void)
197 {
198   if (_options == nullptr)
199   {
200     _options = std::make_unique<OptimizeOptionsImpl>();
201   }
202
203   return _options.get();
204 }
205
206 void CircleOptimizer::optimize(luci::Module *m) const
207 {
208   luci::Phase phase;
209
210   // Following passes are needed everytime when other passes create new node or modify some nodes.
211   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
212   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
213
214   if (_options->query(Options::Algorithm::FuseBCQ))
215   {
216     phase.emplace_back(std::make_unique<FuseBCQPass>());
217   }
218
219   ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
220   PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
221   phase_runner.attach(&prog);
222   phase_runner.run(phase);
223 }
224
225 void CircleOptimizer::optimize(loco::Graph *g) const
226 {
227   logo::Phase phase;
228
229   // Conversion from NCHW to NHWC is done first to avoid interference with other optimizations.
230   if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
231   {
232     bool preserve_input =
233       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_input_shape) != "true";
234     bool preserve_output =
235       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
236
237     convert_nchw_to_nhwc(g, preserve_input, preserve_output);
238   }
239
240   /* TRANSFORM DECLARATION BEGIN */
241   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
242
243   // Following passes are needed everytime when other passes create new node or modify some nodes.
244   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
245   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
246
247   if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
248   {
249     phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
250   }
251   if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
252   {
253     phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
254   }
255   if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
256   {
257     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
258   }
259   if (_options->query(Options::Algorithm::FuseMeanWithMean))
260   {
261     phase.emplace_back(std::make_unique<FuseMeanWithMeanPass>());
262   }
263   if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
264   {
265     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
266   }
267   if (_options->query(Options::Algorithm::FuseInstanceNorm))
268   {
269     phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
270   }
271   if (_options->query(Options::Algorithm::FuseBatchNormWithConv))
272   {
273     phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>());
274   }
275   if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv))
276   {
277     phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>());
278   }
279   if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
280   {
281     phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
282   }
283   if (_options->query(Options::Algorithm::FuseAddWithFullyConnected))
284   {
285     phase.emplace_back(std::make_unique<FuseAddWithFullyConnectedPass>());
286   }
287   if (_options->query(Options::Algorithm::FuseAddWithTConv))
288   {
289     phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
290   }
291   if (_options->query(Options::Algorithm::FuseActivationFunction))
292   {
293     phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
294   }
295   if (_options->query(Options::Algorithm::FuseTransposeWithMean))
296   {
297     phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
298   }
299   if (_options->query(Options::Algorithm::FoldAddV2))
300   {
301     phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
302   }
303   if (_options->query(Options::Algorithm::FoldCast))
304   {
305     phase.emplace_back(std::make_unique<luci::FoldCastPass>());
306   }
307   if (_options->query(Options::Algorithm::FoldDepthwiseConv2D))
308   {
309     phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>());
310   }
311   if (_options->query(Options::Algorithm::FoldDequantize))
312   {
313     phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
314   }
315   if (_options->query(Options::Algorithm::FoldSparseToDense))
316   {
317     phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
318   }
319   if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
320   {
321     phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
322   }
323   if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
324   {
325     phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
326   }
327   if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
328   {
329     phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
330   }
331   if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
332   {
333     phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
334   }
335   if (_options->query(Options::Algorithm::ExpandBroadcastConst))
336   {
337     phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>());
338   }
339   if (_options->query(Options::Algorithm::RemoveFakeQuant))
340   {
341     phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
342   }
343   if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
344   {
345     phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
346   }
347   if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
348   {
349     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
350   }
351   if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
352   {
353     phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>());
354   }
355   if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice))
356   {
357     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>());
358   }
359   if (_options->query(Options::Algorithm::RemoveUnnecessarySplit))
360   {
361     phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>());
362   }
363   if (_options->query(Options::Algorithm::RemoveRedundantReshape))
364   {
365     phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
366   }
367   if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
368   {
369     phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
370   }
371   if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
372   {
373     phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
374   }
375   if (_options->query(Options::Algorithm::ReplaceSubWithAdd))
376   {
377     phase.emplace_back(std::make_unique<luci::ReplaceSubWithAddPass>());
378   }
379   if (_options->query(Options::Algorithm::SubstitutePackToReshape))
380   {
381     phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
382   }
383   if (_options->query(Options::Algorithm::SubstitutePadV2ToPad))
384   {
385     phase.emplace_back(std::make_unique<luci::SubstitutePadV2ToPadPass>());
386   }
387   if (_options->query(Options::Algorithm::SubstituteSplitVToSplit))
388   {
389     phase.emplace_back(std::make_unique<luci::SubstituteSplitVToSplitPass>());
390   }
391   if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
392   {
393     phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
394   }
395   if (_options->query(Options::Algorithm::SubstituteStridedSliceToReshape))
396   {
397     phase.emplace_back(std::make_unique<luci::SubstituteStridedSliceToReshapePass>());
398   }
399   if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
400   {
401     phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
402   }
403   if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass))
404   {
405     phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
406   }
407   if (_options->query(Options::Algorithm::TransformMinReluToRelu6Pass))
408   {
409     phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
410   }
411
412   /* TRANSFORM DECLARATION END */
413
414   ProgressReporter prog(g, logo::PhaseStrategy::Restart);
415   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
416   phase_runner.attach(&prog);
417   phase_runner.run(phase);
418 }
419
// Quantize the graph according to the algorithms selected via Options.
// Stages run in order when enabled: QuantizeDequantizeWeights (fake quant),
// QuantizeWithMinMax (actual quant + verification), Requantize, ForceQuantParam.
// Each stage validates its option strings and throws std::runtime_error on
// unsupported values. Finishes with a shape/type inference phase.
void CircleOptimizer::quantize(loco::Graph *g) const
{
  // Fake quantization of weights
  if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
  {
    // Accepted option values (compared after lower-casing)
    static const std::vector<std::string> fakeq_supported_input_model_dtype{"float32"};
    static const std::vector<std::string> fakeq_supported_output_model_dtype{"uint8", "int16"};
    static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};

    auto input_model_dtype =
      _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
    auto output_model_dtype =
      _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
    auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);

    if (!in_array(to_lower_case(input_model_dtype), fakeq_supported_input_model_dtype))
      throw std::runtime_error("Unsupported input type. List of supported input type: " +
                               to_string(fakeq_supported_input_model_dtype));

    if (!in_array(to_lower_case(output_model_dtype), fakeq_supported_output_model_dtype))
      throw std::runtime_error("Unsupported output type. List of supported output type: " +
                               to_string(fakeq_supported_output_model_dtype));

    if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
      throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
                               to_string(fakeq_supported_granularity));

    // Layer-wise granularity is only defined for uint8 output
    if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
        str_to_dtype(output_model_dtype) != loco::DataType::U8)
      throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");

    // Clear existing quantparams before doing fake quantization
    for (auto node : loco::active_nodes(loco::output_nodes(g)))
    {
      auto circle_node = loco::must_cast<luci::CircleNode *>(node);
      if (circle_node->quantparam() != nullptr)
        circle_node->quantparam(nullptr);
    }

    luci::QuantizeDequantizeWeightsPass fake_quantizer(str_to_dtype(input_model_dtype),
                                                       str_to_dtype(output_model_dtype),
                                                       str_to_granularity(granularity));
    fake_quantizer.run(g);
  }

  // Actual quantization of weights, bias, and activation
  if (_options->query(Options::Algorithm::QuantizeWithMinMax))
  {
    // Accepted option values (compared after lower-casing)
    static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
    static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
    static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
    static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
    static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};

    auto input_model_dtype =
      _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
    auto output_model_dtype =
      _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);
    auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
    // input/output types default to the output model dtype when unspecified
    auto input_type = _options->param(Options::AlgorithmParameters::Quantize_input_type);
    if (input_type.empty())
      input_type = output_model_dtype;
    auto output_type = _options->param(Options::AlgorithmParameters::Quantize_output_type);
    if (output_type.empty())
      output_type = output_model_dtype;

    if (!in_array(to_lower_case(input_model_dtype), qwmm_supported_input_model_dtype))
      throw std::runtime_error("Unsupported input type. List of supported input types: " +
                               to_string(qwmm_supported_input_model_dtype));

    if (!in_array(to_lower_case(output_model_dtype), qwmm_supported_output_model_dtype))
      throw std::runtime_error("Unsupported output type. List of supported output types: " +
                               to_string(qwmm_supported_output_model_dtype));

    if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
      throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
                               to_string(qwmm_supported_granularity));

    if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
      throw std::runtime_error("Unsupported input type. List of supported input types: " +
                               to_string(qwmm_supported_input_type));

    if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
      throw std::runtime_error("Unsupported output type. List of supported output types: " +
                               to_string(qwmm_supported_output_type));

    // Layer-wise granularity is only defined for uint8 output
    if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
        str_to_dtype(output_model_dtype) != loco::DataType::U8)
      throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");

    luci::QuantizeWithMinMaxPass quantizer(
      str_to_dtype(input_model_dtype), str_to_dtype(output_model_dtype),
      str_to_granularity(granularity), str_to_dtype(input_type), str_to_dtype(output_type));
    quantizer.run(g);

    // Post-quantization optimizations
    logo::Phase phase;

    phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());

    phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
    phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
    phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());

    ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
    logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
    phase_runner.attach(&prog);
    phase_runner.run(phase);

    // Verify the type/granularity of the quantized model
    luci::QuantizedModelVerifier verifier(str_to_dtype(output_model_dtype),
                                          str_to_granularity(granularity));
    verifier.verify(g);
  }

  // Requantize
  if (_options->query(Options::Algorithm::Requantize))
  {
    // Requantize currently only supports int8 -> uint8
    static const std::vector<std::string> rq_supported_input_model_dtype{"int8"};
    static const std::vector<std::string> rq_supported_output_model_dtype{"uint8"};

    auto input_model_dtype =
      _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
    auto output_model_dtype =
      _options->param(Options::AlgorithmParameters::Quantize_output_model_dtype);

    if (!in_array(to_lower_case(input_model_dtype), rq_supported_input_model_dtype))
      throw std::runtime_error("Unsupported input type. List of supported input types: " +
                               to_string(rq_supported_input_model_dtype));

    if (!in_array(to_lower_case(output_model_dtype), rq_supported_output_model_dtype))
      throw std::runtime_error("Unsupported output type. List of supported output types: " +
                               to_string(rq_supported_output_model_dtype));

    luci::RequantizePass requantizer(str_to_dtype(input_model_dtype),
                                     str_to_dtype(output_model_dtype));
    requantizer.run(g);
  }

  // Force to write quantparam to specified tensors
  // NOTE Only per-tensor (not per-channel) qparam can be written
  if (_options->query(Options::Algorithm::ForceQuantParam))
  {
    ForceQuantParamPass::TensorVector tensors =
      _options->params(Options::AlgorithmParameters::Quantize_tensor_names);
    auto str_scales = _options->params(Options::AlgorithmParameters::Quantize_scales);
    auto str_zero_points = _options->params(Options::AlgorithmParameters::Quantize_zero_points);

    // Cast scales/zero_points to proper types
    ForceQuantParamPass::ScaleVector scales = lexical_cast<float>(str_scales);
    ForceQuantParamPass::ZPVector zero_points = lexical_cast<int64_t>(str_zero_points);

    ForceQuantParamPass fq(tensors, scales, zero_points);
    fq.run(g);
  }

  logo::Phase phase;

  // Do Shape/Type inference
  phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
  phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());

  ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
  logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
  phase_runner.attach(&prog);
  phase_runner.run(phase);
}
587
588 void CircleOptimizer::sparsify(loco::Graph *g) const
589 {
590   if (_options->query(Options::Algorithm::SparsifyTensorPass))
591   {
592     std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
593     std::string str_tarversal_order =
594       _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
595     std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
596     std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
597     std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
598
599     // traversal order
600     std::vector<int32_t> traversal_order = pepper::csv_to_vector<int32_t>(str_tarversal_order);
601     // format
602     std::vector<DimensionType> format;
603     std::istringstream is(str_format);
604     for (char c; is >> c;)
605     {
606       assert(c != ',');
607       if (c == 'd')
608         format.push_back(DimensionType::DENSE);
609       else if (c == 's')
610         format.push_back(DimensionType::SPARSE_CSR);
611       if (is.peek() == ',')
612         is.ignore();
613     }
614     // block size
615     std::vector<int32_t> block_size = pepper::csv_to_vector<int32_t>(str_block_size);
616     // block map
617     std::vector<int32_t> block_map = pepper::csv_to_vector<int32_t>(str_block_map);
618
619     luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
620                                         block_map};
621     sparsifier.run(g);
622   }
623 }
624
625 } // namespace luci