5e1613ad963fd1834a17fc1deb326842a6b8ed9e
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/CircleOptimizer.h"
18
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/ExpandBroadcastConstPass.h"
21 #include "luci/Pass/FoldAddV2Pass.h"
22 #include "luci/Pass/FoldCastPass.h"
23 #include "luci/Pass/FoldDensifyPass.h"
24 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
25 #include "luci/Pass/FoldDequantizePass.h"
26 #include "luci/Pass/FoldFullyConnectedPass.h"
27 #include "luci/Pass/FoldGatherPass.h"
28 #include "luci/Pass/FoldSparseToDensePass.h"
29 #include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
30 #include "luci/Pass/ForwardTransposeOpPass.h"
31 #include "luci/Pass/FuseActivationFunctionPass.h"
32 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
33 #include "luci/Pass/FuseAddWithTConvPass.h"
34 #include "luci/Pass/FuseBatchNormWithConvPass.h"
35 #include "luci/Pass/FuseBatchNormWithDwConvPass.h"
36 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
37 #include "luci/Pass/FuseBCQPass.h"
38 #include "luci/Pass/FuseInstanceNormPass.h"
39 #include "luci/Pass/FuseMeanWithMeanPass.h"
40 #include "luci/Pass/FusePreActivationBatchNormPass.h"
41 #include "luci/Pass/FusePReluPass.h"
42 #include "luci/Pass/FuseTransposeWithMeanPass.h"
43 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
44 #include "luci/Pass/RemoveDuplicateConstPass.h"
45 #include "luci/Pass/RemoveFakeQuantPass.h"
46 #include "luci/Pass/RemoveQuantDequantSeqPass.h"
47 #include "luci/Pass/RemoveRedundantReshapePass.h"
48 #include "luci/Pass/RemoveRedundantTransposePass.h"
49 #include "luci/Pass/RemoveRedundantQuantizePass.h"
50 #include "luci/Pass/RemoveUnnecessaryReshapePass.h"
51 #include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
52 #include "luci/Pass/RemoveUnnecessarySlicePass.h"
53 #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
54 #include "luci/Pass/RemoveUnnecessarySplitPass.h"
55 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
56 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
57 #include "luci/Pass/ReplaceSubWithAddPass.h"
58 #include "luci/Pass/ResolveCustomOpAddPass.h"
59 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
60 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
61 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
62 #include "luci/Pass/ResolveCustomOpSplitVPass.h"
63 #include "luci/Pass/SparsifyTensorPass.h"
64 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
65 #include "luci/Pass/SubstitutePackToReshapePass.h"
66 #include "luci/Pass/SubstitutePadV2ToPadPass.h"
67 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
68 #include "luci/Pass/SubstituteSqueezeToReshapePass.h"
69 #include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
70 #include "luci/Pass/SubstituteTransposeToReshapePass.h"
71 #include "luci/Pass/TransformMinMaxToRelu6Pass.h"
72 #include "luci/Pass/TransformMinReluToRelu6Pass.h"
73 #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
74 // TODO add more passes
75
76 #include "luci/Pass/CircleShapeInferencePass.h"
77 #include "luci/Pass/CircleTypeInferencePass.h"
78
79 // logo passes
80 #include <logo/RemoveDeadNodeWithQueryPass.h>
81
82 #include "ModulePhase.h"
83 #include "ProgressReporter.h"
84
85 #include <luci/IR/CircleNodes.h>
86 #include <logo/Phase.h>
87 #include <pepper/csv2vec.h>
88
89 #include <memory>
90 #include <sstream>
91
92 namespace
93 {
94
95 using namespace luci;
96
97 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
98 {
99 public:
100   void enable(Algorithm) final;
101   void param(AlgorithmParameters, const std::string &) final;
102   const std::string param(AlgorithmParameters) const final;
103   bool query(Algorithm) final;
104
105 private:
106   std::vector<Algorithm> _algorithms;
107   std::map<AlgorithmParameters, const std::string> _algorithm_params;
108 };
109
110 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
111
112 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
113 {
114   _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
115 }
116
117 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
118 {
119   auto param_str = _algorithm_params.find(param);
120   if (param_str != _algorithm_params.end())
121   {
122     return param_str->second;
123   }
124   else
125   {
126     return std::string();
127   }
128 }
129
130 bool OptimizeOptionsImpl::query(Algorithm algo)
131 {
132   std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
133   if (it == _algorithms.end())
134     return false;
135
136   return true;
137 }
138
139 // TODO Make a struct for args
140 void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc)
141 {
142   logo::Phase phase;
143
144   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
145   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
146   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
147
148   // Resolve custom Ops
149   phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
150   phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
151   phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
152   phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
153   phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
154
155   // Fuse FullyConnected with Add
156   // Why we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass?
157   // FullyConnected Op's layout is not changed in ConvertNCHWToNHWCPass, while
158   // Add Op's layer is changed from NCHW to NHWC.
159   // This disables fusion of Add and FullyConnected after ConvertNCHWToNHWC.
160   if (fuse_fc)
161     phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
162
163   phase.emplace_back(
164     std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
165
166   ProgressReporter prog(g, logo::PhaseStrategy::Restart);
167   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
168   phase_runner.attach(&prog);
169   phase_runner.run(phase);
170 }
171
172 } // namespace
173
174 namespace luci
175 {
176
177 CircleOptimizer::Options *CircleOptimizer::options(void)
178 {
179   if (_options == nullptr)
180   {
181     _options = std::make_unique<OptimizeOptionsImpl>();
182   }
183
184   return _options.get();
185 }
186
187 void CircleOptimizer::optimize(luci::Module *m) const
188 {
189   luci::Phase phase;
190
191   // Following passes are needed everytime when other passes create new node or modify some nodes.
192   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
193   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
194
195   if (_options->query(Options::Algorithm::FuseBCQ))
196   {
197     phase.emplace_back(std::make_unique<FuseBCQPass>());
198   }
199
200   ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
201   PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
202   phase_runner.attach(&prog);
203   phase_runner.run(phase);
204 }
205
206 void CircleOptimizer::optimize(loco::Graph *g) const
207 {
208   logo::Phase phase;
209
210   // Conversion from NCHW to NHWC is done first to avoid interference with other optimizations.
211   if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
212   {
213     bool preserve_input =
214       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_input_shape) != "true";
215     bool preserve_output =
216       _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
217
218     bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
219
220     convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc);
221   }
222
223   /* TRANSFORM DECLARATION BEGIN */
224   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
225
226   // Following passes are needed everytime when other passes create new node or modify some nodes.
227   phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
228   phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
229
230   if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
231   {
232     phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
233   }
234   if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
235   {
236     phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
237   }
238   if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
239   {
240     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
241   }
242   if (_options->query(Options::Algorithm::FuseMeanWithMean))
243   {
244     phase.emplace_back(std::make_unique<FuseMeanWithMeanPass>());
245   }
246   if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
247   {
248     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
249   }
250   if (_options->query(Options::Algorithm::ResolveCustomOpSplitV))
251   {
252     phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
253   }
254   if (_options->query(Options::Algorithm::FuseInstanceNorm))
255   {
256     phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
257   }
258   if (_options->query(Options::Algorithm::FuseBatchNormWithConv))
259   {
260     phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>());
261   }
262   if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv))
263   {
264     phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>());
265   }
266   if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
267   {
268     phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
269   }
270   if (_options->query(Options::Algorithm::FuseAddWithFullyConnected))
271   {
272     phase.emplace_back(std::make_unique<FuseAddWithFullyConnectedPass>());
273   }
274   if (_options->query(Options::Algorithm::FuseAddWithTConv))
275   {
276     phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
277   }
278   if (_options->query(Options::Algorithm::FuseActivationFunction))
279   {
280     phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
281   }
282   if (_options->query(Options::Algorithm::FusePRelu))
283   {
284     phase.emplace_back(std::make_unique<FusePReluPass>());
285   }
286   if (_options->query(Options::Algorithm::FuseTransposeWithMean))
287   {
288     phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
289   }
290   if (_options->query(Options::Algorithm::FoldAddV2))
291   {
292     phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
293   }
294   if (_options->query(Options::Algorithm::FoldCast))
295   {
296     phase.emplace_back(std::make_unique<luci::FoldCastPass>());
297   }
298   if (_options->query(Options::Algorithm::FoldDensify))
299   {
300     phase.emplace_back(std::make_unique<luci::FoldDensifyPass>());
301   }
302   if (_options->query(Options::Algorithm::FoldDepthwiseConv2D))
303   {
304     phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>());
305   }
306   if (_options->query(Options::Algorithm::FoldDequantize))
307   {
308     phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
309   }
310   if (_options->query(Options::Algorithm::FoldFullyConnected))
311   {
312     phase.emplace_back(std::make_unique<luci::FoldFullyConnectedPass>());
313   }
314   if (_options->query(Options::Algorithm::FoldGather))
315   {
316     phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
317   }
318   if (_options->query(Options::Algorithm::FoldSparseToDense))
319   {
320     phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
321   }
322   if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
323   {
324     phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
325   }
326   if (_options->query(Options::Algorithm::ForwardTransposeOp))
327   {
328     phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
329   }
330   if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
331   {
332     phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
333   }
334   if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
335   {
336     phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
337   }
338   if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
339   {
340     phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
341   }
342   if (_options->query(Options::Algorithm::ExpandBroadcastConst))
343   {
344     phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>());
345   }
346   if (_options->query(Options::Algorithm::RemoveDuplicateConst))
347   {
348     phase.emplace_back(std::make_unique<luci::RemoveDuplicateConstPass>());
349   }
350   if (_options->query(Options::Algorithm::RemoveFakeQuant))
351   {
352     phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
353   }
354   if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
355   {
356     phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
357   }
358   if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
359   {
360     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
361     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapeNetPass>());
362   }
363   if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
364   {
365     phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>());
366   }
367   if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice))
368   {
369     phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>());
370   }
371   if (_options->query(Options::Algorithm::RemoveUnnecessarySplit))
372   {
373     phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>());
374   }
375   if (_options->query(Options::Algorithm::RemoveRedundantReshape))
376   {
377     phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
378   }
379   if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
380   {
381     phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
382   }
383   if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
384   {
385     phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
386   }
387   if (_options->query(Options::Algorithm::ReplaceNonConstFCWithBatchMatMul))
388   {
389     phase.emplace_back(std::make_unique<luci::ReplaceNonConstFCWithBatchMatMulPass>());
390   }
391   if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
392   {
393     phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
394   }
395   if (_options->query(Options::Algorithm::ReplaceSubWithAdd))
396   {
397     phase.emplace_back(std::make_unique<luci::ReplaceSubWithAddPass>());
398   }
399   if (_options->query(Options::Algorithm::SubstitutePackToReshape))
400   {
401     phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
402   }
403   if (_options->query(Options::Algorithm::SubstitutePadV2ToPad))
404   {
405     phase.emplace_back(std::make_unique<luci::SubstitutePadV2ToPadPass>());
406   }
407   if (_options->query(Options::Algorithm::SubstituteSplitVToSplit))
408   {
409     phase.emplace_back(std::make_unique<luci::SubstituteSplitVToSplitPass>());
410   }
411   if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
412   {
413     phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
414   }
415   if (_options->query(Options::Algorithm::SubstituteStridedSliceToReshape))
416   {
417     phase.emplace_back(std::make_unique<luci::SubstituteStridedSliceToReshapePass>());
418   }
419   if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
420   {
421     phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
422   }
423   if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass))
424   {
425     phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
426   }
427   if (_options->query(Options::Algorithm::TransformMinReluToRelu6Pass))
428   {
429     phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
430   }
431   if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
432   {
433     phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
434   }
435
436   /* TRANSFORM DECLARATION END */
437
438   ProgressReporter prog(g, logo::PhaseStrategy::Restart);
439   logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
440   phase_runner.attach(&prog);
441   phase_runner.run(phase);
442 }
443
444 void CircleOptimizer::sparsify(loco::Graph *g) const
445 {
446   if (_options->query(Options::Algorithm::SparsifyTensorPass))
447   {
448     std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
449     std::string str_tarversal_order =
450       _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
451     std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
452     std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
453     std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
454
455     // traversal order
456     std::vector<int32_t> traversal_order = pepper::csv_to_vector<int32_t>(str_tarversal_order);
457     // format
458     std::vector<DimensionType> format;
459     std::istringstream is(str_format);
460     for (char c; is >> c;)
461     {
462       assert(c != ',');
463       if (c == 'd')
464         format.push_back(DimensionType::DENSE);
465       else if (c == 's')
466         format.push_back(DimensionType::SPARSE_CSR);
467       if (is.peek() == ',')
468         is.ignore();
469     }
470     // block size
471     std::vector<int32_t> block_size = pepper::csv_to_vector<int32_t>(str_block_size);
472     // block map
473     std::vector<int32_t> block_map = pepper::csv_to_vector<int32_t>(str_block_map);
474
475     luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,
476                                         block_map};
477     sparsifier.run(g);
478   }
479 }
480
481 } // namespace luci