Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/CircleOptimizer.h"
18
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/ExpandBroadcastConstPass.h"
21 #include "luci/Pass/FoldAddV2Pass.h"
22 #include "luci/Pass/FoldCastPass.h"
23 #include "luci/Pass/FoldDensifyPass.h"
24 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
25 #include "luci/Pass/FoldDequantizePass.h"
26 #include "luci/Pass/FoldFullyConnectedPass.h"
27 #include "luci/Pass/FoldGatherPass.h"
28 #include "luci/Pass/FoldSparseToDensePass.h"
29 #include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
30 #include "luci/Pass/ForwardTransposeOpPass.h"
31 #include "luci/Pass/FuseActivationFunctionPass.h"
32 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
33 #include "luci/Pass/FuseAddWithTConvPass.h"
34 #include "luci/Pass/FuseBatchNormWithConvPass.h"
35 #include "luci/Pass/FuseBatchNormWithDwConvPass.h"
36 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
37 #include "luci/Pass/FuseBCQPass.h"
38 #include "luci/Pass/FuseInstanceNormPass.h"
39 #include "luci/Pass/FuseMeanWithMeanPass.h"
40 #include "luci/Pass/FusePreActivationBatchNormPass.h"
41 #include "luci/Pass/FusePReluPass.h"
42 #include "luci/Pass/FuseGeluPass.h"
43 #include "luci/Pass/FuseTransposeWithMeanPass.h"
44 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
45 #include "luci/Pass/RemoveDuplicateConstPass.h"
46 #include "luci/Pass/RemoveFakeQuantPass.h"
47 #include "luci/Pass/RemoveQuantDequantSeqPass.h"
48 #include "luci/Pass/RemoveRedundantReshapePass.h"
49 #include "luci/Pass/RemoveRedundantTransposePass.h"
50 #include "luci/Pass/RemoveRedundantQuantizePass.h"
51 #include "luci/Pass/RemoveUnnecessaryReshapePass.h"
52 #include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
53 #include "luci/Pass/RemoveUnnecessarySlicePass.h"
54 #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
55 #include "luci/Pass/RemoveUnnecessarySplitPass.h"
56 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
57 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
58 #include "luci/Pass/ReplaceSubWithAddPass.h"
59 #include "luci/Pass/ResolveCustomOpAddPass.h"
60 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
61 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
62 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
63 #include "luci/Pass/ResolveCustomOpSplitVPass.h"
64 #include "luci/Pass/SparsifyTensorPass.h"
65 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
66 #include "luci/Pass/SubstitutePackToReshapePass.h"
67 #include "luci/Pass/SubstitutePadV2ToPadPass.h"
68 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
69 #include "luci/Pass/SubstituteSqueezeToReshapePass.h"
70 #include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
71 #include "luci/Pass/SubstituteTransposeToReshapePass.h"
72 #include "luci/Pass/TransformMinMaxToRelu6Pass.h"
73 #include "luci/Pass/TransformMinReluToRelu6Pass.h"
74 #include "luci/Pass/DecomposeHardSwishPass.h"
75 #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
76 // TODO add more passes
77
78 #include "luci/Pass/CircleShapeInferencePass.h"
79 #include "luci/Pass/CircleTypeInferencePass.h"
80
81 // logo passes
82 #include <logo/RemoveDeadNodeWithQueryPass.h>
83
84 #include "ModulePhase.h"
85 #include "ProgressReporter.h"
86
87 #include <luci/IR/CircleNodes.h>
88 #include <logo/Phase.h>
89 #include <pepper/csv2vec.h>
90
91 #include <memory>
92 #include <sstream>
93
94 namespace
95 {
96
97 using namespace luci;
98
99 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
100 {
101 public:
102   void enable(Algorithm) final;
103   void param(AlgorithmParameters, const std::string &) final;
104   const std::string param(AlgorithmParameters) const final;
105   bool query(Algorithm) final;
106
107 private:
108   std::vector<Algorithm> _algorithms;
109   std::map<AlgorithmParameters, const std::string> _algorithm_params;
110 };
111
112 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
113
114 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
115 {
116   _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
117 }
118
119 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
120 {
121   auto param_str = _algorithm_params.find(param);
122   if (param_str != _algorithm_params.end())
123   {
124     return param_str->second;
125   }
126   else
127   {
128     return std::string();
129   }
130 }
131
132 bool OptimizeOptionsImpl::query(Algorithm algo)
133 {
134   std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
135   if (it == _algorithms.end())
136     return false;
137
138   return true;
139 }
140
141 // TODO Make a struct for args
142 void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc,
143                           bool fuse_gelu)
144 {
145   logo::Phase phase;
146
147   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
148   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
149   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
150
151   // Resolve custom Ops
152   phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
153   phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
154   phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
155   phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
156   phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
157
158   // Fuse FullyConnected with Add
159   // Why we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass?
160   // FullyConnected Op's layout is not changed in ConvertNCHWToNHWCPass, while
161   // Add Op's layer is changed from NCHW to NHWC.
162   // This disables fusion of Add and FullyConnected after ConvertNCHWToNHWC.
163   if (fuse_fc)
164     phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
165
166   // Fuse decomposed ops to Gelu Op
167   // Why here? ConverNCHWToNHWCPass inserts additional Ops, so it is better to fuse
168   // Gelu in advance.
169   if (fuse_gelu)
170     phase.emplace_back(std::make_unique<luci::FuseGeluPass>());
171
172   phase.emplace_back(
173     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
174
175   ProgressReporter prog(g, logo::PhaseStrategy::Restart);
176   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
177   phase_runner.attach(&prog);
178   phase_runner.run(phase);
179 }
180
181 } // namespace
182
183 namespace luci
184 {
185
186 CircleOptimizer::Options *CircleOptimizer::options(void)
187 {
188   if (_options == nullptr)
189   {
190     _options = std::make_unique<OptimizeOptionsImpl>();
191   }
192
193   return _options.get();
194 }
195
196 void CircleOptimizer::optimize(luci::Module *m) const
197 {
198   luci::Phase phase;
199
200   // Following passes are needed everytime when other passes create new node or modify some nodes.
201   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
202   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
203
204   if (_options->query(Options::Algorithm::FuseBCQ))
205   {
206     phase.emplace_back(std::make_unique<FuseBCQPass>());
207   }
208
209   ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
210   PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
211   phase_runner.attach(&prog);
212   phase_runner.run(phase);
213 }
214
215 void CircleOptimizer::optimize(loco::Graph *g) const
216 {
217   logo::Phase phase;
218
219   // Conversion from NCHW to NHWC is done first to avoid interference with other optimizations.
220   if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
221   {
222     bool preserve_input =
223       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_input_shape) != "true";
224     bool preserve_output =
225       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
226
227     bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
228     bool fuse_gelu = _options->query(Options::Algorithm::FuseGelu);
229
230     convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc, fuse_gelu);
231   }
232
233   /* TRANSFORM DECLARATION BEGIN */
234   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
235
236   // Following passes are needed everytime when other passes create new node or modify some nodes.
237   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
238   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
239
240   if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
241   {
242     phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
243   }
244   if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
245   {
246     phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
247   }
248   if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
249   {
250     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
251   }
252   if (_options->query(Options::Algorithm::FuseMeanWithMean))
253   {
254     phase.emplace_back(std::make_unique<FuseMeanWithMeanPass>());
255   }
256   if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
257   {
258     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
259   }
260   if (_options->query(Options::Algorithm::ResolveCustomOpSplitV))
261   {
262     phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
263   }
264   if (_options->query(Options::Algorithm::FuseInstanceNorm))
265   {
266     phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
267   }
268   if (_options->query(Options::Algorithm::FuseBatchNormWithConv))
269   {
270     phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>());
271   }
272   if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv))
273   {
274     phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>());
275   }
276   if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
277   {
278     phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
279   }
280   if (_options->query(Options::Algorithm::FuseAddWithFullyConnected))
281   {
282     phase.emplace_back(std::make_unique<FuseAddWithFullyConnectedPass>());
283   }
284   if (_options->query(Options::Algorithm::FuseAddWithTConv))
285   {
286     phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
287   }
288   if (_options->query(Options::Algorithm::FuseActivationFunction))
289   {
290     phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
291   }
292   if (_options->query(Options::Algorithm::FusePRelu))
293   {
294     phase.emplace_back(std::make_unique<FusePReluPass>());
295   }
296   if (_options->query(Options::Algorithm::FuseGelu))
297   {
298     phase.emplace_back(std::make_unique<FuseGeluPass>());
299   }
300   if (_options->query(Options::Algorithm::FuseTransposeWithMean))
301   {
302     phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
303   }
304   if (_options->query(Options::Algorithm::FoldAddV2))
305   {
306     phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
307   }
308   if (_options->query(Options::Algorithm::FoldCast))
309   {
310     phase.emplace_back(std::make_unique<luci::FoldCastPass>());
311   }
312   if (_options->query(Options::Algorithm::FoldDensify))
313   {
314     phase.emplace_back(std::make_unique<luci::FoldDensifyPass>());
315   }
316   if (_options->query(Options::Algorithm::FoldDepthwiseConv2D))
317   {
318     phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>());
319   }
320   if (_options->query(Options::Algorithm::FoldDequantize))
321   {
322     phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
323   }
324   if (_options->query(Options::Algorithm::FoldFullyConnected))
325   {
326     phase.emplace_back(std::make_unique<luci::FoldFullyConnectedPass>());
327   }
328   if (_options->query(Options::Algorithm::FoldGather))
329   {
330     phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
331   }
332   if (_options->query(Options::Algorithm::FoldSparseToDense))
333   {
334     phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
335   }
336   if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
337   {
338     phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
339   }
340   if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
341   {
342     phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
343   }
344   if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
345   {
346     phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
347   }
348   if (_options->query(Options::Algorithm::ExpandBroadcastConst))
349   {
350     phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>());
351   }
352   if (_options->query(Options::Algorithm::RemoveDuplicateConst))
353   {
354     phase.emplace_back(std::make_unique<luci::RemoveDuplicateConstPass>());
355   }
356   if (_options->query(Options::Algorithm::RemoveFakeQuant))
357   {
358     phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
359   }
360   if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
361   {
362     phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
363   }
364   if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
365   {
366     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
367     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapeNetPass>());
368   }
369   if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
370   {
371     phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>());
372   }
373   if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice))
374   {
375     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>());
376   }
377   if (_options->query(Options::Algorithm::RemoveUnnecessarySplit))
378   {
379     phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>());
380   }
381   if (_options->query(Options::Algorithm::RemoveRedundantReshape))
382   {
383     phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
384   }
385   if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
386   {
387     phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
388   }
389   if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
390   {
391     phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
392   }
393   if (_options->query(Options::Algorithm::ReplaceNonConstFCWithBatchMatMul))
394   {
395     phase.emplace_back(std::make_unique<luci::ReplaceNonConstFCWithBatchMatMulPass>());
396   }
397   if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
398   {
399     phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
400   }
401   if (_options->query(Options::Algorithm::ReplaceSubWithAdd))
402   {
403     phase.emplace_back(std::make_unique<luci::ReplaceSubWithAddPass>());
404   }
405   if (_options->query(Options::Algorithm::SubstitutePackToReshape))
406   {
407     phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
408   }
409   if (_options->query(Options::Algorithm::SubstitutePadV2ToPad))
410   {
411     phase.emplace_back(std::make_unique<luci::SubstitutePadV2ToPadPass>());
412   }
413   if (_options->query(Options::Algorithm::SubstituteSplitVToSplit))
414   {
415     phase.emplace_back(std::make_unique<luci::SubstituteSplitVToSplitPass>());
416   }
417   if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
418   {
419     phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
420   }
421   if (_options->query(Options::Algorithm::SubstituteStridedSliceToReshape))
422   {
423     phase.emplace_back(std::make_unique<luci::SubstituteStridedSliceToReshapePass>());
424   }
425   if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
426   {
427     phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
428   }
429   if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass))
430   {
431     phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
432   }
433   if (_options->query(Options::Algorithm::TransformMinReluToRelu6Pass))
434   {
435     phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
436   }
437   if (_options->query(Options::Algorithm::DecomposeHardSwishPass))
438   {
439     phase.emplace_back(std::make_unique<luci::DecomposeHardSwishPass>());
440   }
441   if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
442   {
443     phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
444   }
445   // Forward Reshape/Transpose is done after
446   // 1. SubstituteXXXToReshape
447   // 2. RemoveRedundantReshape/Transpose
448   // See https://github.com/Samsung/ONE/pull/10596 for more details
449   if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
450   {
451     phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
452   }
453   if (_options->query(Options::Algorithm::ForwardTransposeOp))
454   {
455     phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
456   }
457
458   /* TRANSFORM DECLARATION END */
459
460   ProgressReporter prog(g, logo::PhaseStrategy::Restart);
461   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
462   phase_runner.attach(&prog);
463   phase_runner.run(phase);
464 }
465
466 void CircleOptimizer::sparsify(loco::Graph *g) const
467 {
468   if (_options->query(Options::Algorithm::SparsifyTensorPass))
469   {
470     std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
471     std::string str_tarversal_order =
472       _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
473     std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
474     std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
475     std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
476
477     // traversal order
478     std::vector<int32_t> traversal_order = pepper::csv_to_vector<int32_t>(str_tarversal_order);
479     // format
480     std::vector<DimensionType> format;
481     std::istringstream is(str_format);
482     for (char c; is >> c;)
483     {
484       assert(c != ',');
485       if (c == 'd')
486         format.push_back(DimensionType::DENSE);
487       else if (c == 's')
488         format.push_back(DimensionType::SPARSE_CSR);
489       if (is.peek() == ',')
490         is.ignore();
491     }
492     // block size
493     std::vector<int32_t> block_size = pepper::csv_to_vector<int32_t>(str_block_size);
494     // block map
495     std::vector<int32_t> block_map = pepper::csv_to_vector<int32_t>(str_block_map);
496
497     luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
498                                         block_map};
499     sparsifier.run(g);
500   }
501 }
502
503 } // namespace luci