2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/CircleOptimizer.h"
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/ExpandBroadcastConstPass.h"
21 #include "luci/Pass/FoldAddV2Pass.h"
22 #include "luci/Pass/FoldCastPass.h"
23 #include "luci/Pass/FoldDensifyPass.h"
24 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
25 #include "luci/Pass/FoldDequantizePass.h"
26 #include "luci/Pass/FoldFullyConnectedPass.h"
27 #include "luci/Pass/FoldGatherPass.h"
28 #include "luci/Pass/FoldSparseToDensePass.h"
29 #include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
30 #include "luci/Pass/ForwardTransposeOpPass.h"
31 #include "luci/Pass/FuseActivationFunctionPass.h"
32 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
33 #include "luci/Pass/FuseAddWithTConvPass.h"
34 #include "luci/Pass/FuseBatchNormWithConvPass.h"
35 #include "luci/Pass/FuseBatchNormWithDwConvPass.h"
36 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
37 #include "luci/Pass/FuseBCQPass.h"
38 #include "luci/Pass/FuseInstanceNormPass.h"
39 #include "luci/Pass/FuseMeanWithMeanPass.h"
40 #include "luci/Pass/FusePreActivationBatchNormPass.h"
41 #include "luci/Pass/FusePReluPass.h"
42 #include "luci/Pass/FuseGeluPass.h"
43 #include "luci/Pass/FuseTransposeWithMeanPass.h"
44 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
45 #include "luci/Pass/RemoveDuplicateConstPass.h"
46 #include "luci/Pass/RemoveFakeQuantPass.h"
47 #include "luci/Pass/RemoveQuantDequantSeqPass.h"
48 #include "luci/Pass/RemoveRedundantReshapePass.h"
49 #include "luci/Pass/RemoveRedundantTransposePass.h"
50 #include "luci/Pass/RemoveRedundantQuantizePass.h"
51 #include "luci/Pass/RemoveUnnecessaryReshapePass.h"
52 #include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
53 #include "luci/Pass/RemoveUnnecessarySlicePass.h"
54 #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
55 #include "luci/Pass/RemoveUnnecessarySplitPass.h"
56 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
57 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
58 #include "luci/Pass/ReplaceSubWithAddPass.h"
59 #include "luci/Pass/ResolveCustomOpAddPass.h"
60 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
61 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
62 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
63 #include "luci/Pass/ResolveCustomOpSplitVPass.h"
64 #include "luci/Pass/SparsifyTensorPass.h"
65 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
66 #include "luci/Pass/SubstitutePackToReshapePass.h"
67 #include "luci/Pass/SubstitutePadV2ToPadPass.h"
68 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
69 #include "luci/Pass/SubstituteSqueezeToReshapePass.h"
70 #include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
71 #include "luci/Pass/SubstituteTransposeToReshapePass.h"
72 #include "luci/Pass/TransformMinMaxToRelu6Pass.h"
73 #include "luci/Pass/TransformMinReluToRelu6Pass.h"
74 #include "luci/Pass/DecomposeHardSwishPass.h"
75 #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
76 // TODO add more passes
78 #include "luci/Pass/CircleShapeInferencePass.h"
79 #include "luci/Pass/CircleTypeInferencePass.h"
82 #include <logo/RemoveDeadNodeWithQueryPass.h>
84 #include "ModulePhase.h"
85 #include "ProgressReporter.h"
87 #include <luci/IR/CircleNodes.h>
88 #include <logo/Phase.h>
89 #include <pepper/csv2vec.h>
// Concrete luci::CircleOptimizer::Options implementation.
// Enabled algorithms are kept in a vector in the order enable() was called, with no
// de-duplication. String-valued algorithm parameters are kept in a map keyed by parameter.
99 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
// Marks an algorithm as enabled.
102 void enable(Algorithm) final;
// Records the string value of an algorithm parameter.
103 void param(AlgorithmParameters, const std::string &) final;
// Returns the string value of an algorithm parameter (an empty string when it was never set).
104 const std::string param(AlgorithmParameters) const final;
// Reports whether an algorithm has been enabled.
105 bool query(Algorithm) final;
// Algorithms passed to enable(), in call order.
108 std::vector<Algorithm> _algorithms;
// Parameter values passed to param(). The mapped type is const, so an existing entry
// cannot be reassigned in place.
109 std::map<AlgorithmParameters, const std::string> _algorithm_params;
112 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
114 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
116 _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
119 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
121 auto param_str = _algorithm_params.find(param);
122 if (param_str != _algorithm_params.end())
124 return param_str->second;
128 return std::string();
// Reports whether `algo` was registered through enable(), using a linear std::find over
// _algorithms.
// NOTE(review): the return statements are not visible in this view. Presumably this returns
// false when `algo` is not found and true otherwise — confirm.
132 bool OptimizeOptionsImpl::query(Algorithm algo)
134 std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
135 if (it == _algorithms.end())
// Runs a dedicated phase on `g` that converts its layout from NCHW to NHWC.
// The phase first removes dead nodes, re-runs shape/type inference and resolves custom Ops.
// It then fuses Add+FullyConnected and decomposed Gelu patterns, and finally runs
// ConvertNCHWToNHWCPass with `preserve_input` / `preserve_output`.
// The phase is executed with the Restart strategy, and progress is reported via ProgressReporter.
// NOTE(review): fuse_fc and fuse_gelu are not referenced in the lines visible here. Presumably
// they gate FuseAddWithFullyConnectedPass and FuseGeluPass below — confirm.
141 // TODO Make a struct for args
142 void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc,
// Cleanup and inference passes, so later passes see up-to-date shapes/types.
147 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
148 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
149 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
151 // Resolve custom Ops
152 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
153 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
154 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
155 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
156 phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
158 // Fuse FullyConnected with Add
159 // Why do we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass?
160 // FullyConnected Op's layout is not changed in ConvertNCHWToNHWCPass, while
161 // Add Op's layout is changed from NCHW to NHWC.
162 // This disables fusion of Add and FullyConnected after ConvertNCHWToNHWC.
164 phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
166 // Fuse decomposed ops to Gelu Op
167 // Why here? ConvertNCHWToNHWCPass inserts additional Ops, so it is better to fuse
170 phase.emplace_back(std::make_unique<luci::FuseGeluPass>());
// The layout conversion itself; the preserve flags are forwarded to the pass.
173 std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
175 ProgressReporter prog(g, logo::PhaseStrategy::Restart);
176 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
177 phase_runner.attach(&prog);
178 phase_runner.run(phase);
186 CircleOptimizer::Options *CircleOptimizer::options(void)
188 if (_options == nullptr)
190 _options = std::make_unique<OptimizeOptionsImpl>();
193 return _options.get();
// Runs module-level optimization passes on `m`.
// The phase always contains shape and type inference, plus FuseBCQPass when FuseBCQ is enabled.
// It is executed with the Restart strategy, and progress is reported via ModuleProgressReporter.
// NOTE(review): _options is created lazily by options(). If options() was never called, the
// _options->query() below may dereference a null pointer — confirm callers always call options()
// first.
196 void CircleOptimizer::optimize(luci::Module *m) const
200 // Following passes are needed every time other passes create new nodes or modify existing ones.
201 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
202 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
204 if (_options->query(Options::Algorithm::FuseBCQ))
206 phase.emplace_back(std::make_unique<FuseBCQPass>());
209 ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
210 PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
211 phase_runner.attach(&prog);
212 phase_runner.run(phase);
// Runs the graph-level optimization pipeline on `g`.
// If ConvertNCHWToNHWC is enabled, the layout conversion runs first in its own phase
// (convert_nchw_to_nhwc). Then a single phase is assembled from every enabled algorithm,
// in the order listed below. It is executed with the Restart strategy, and progress is
// reported via ProgressReporter.
// NOTE(review): _options is created lazily by options(). If options() was never called, the
// _options->query() calls below may dereference a null pointer — confirm callers always call
// options() first.
215 void CircleOptimizer::optimize(loco::Graph *g) const
219 // Conversion from NCHW to NHWC is done first to avoid interference with other optimizations.
220 if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
// Input/output shapes are preserved unless the corresponding parameter is exactly "true".
222 bool preserve_input =
223 _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_input_shape) != "true";
224 bool preserve_output =
225 _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
// Forward the fusion switches so the conversion phase can fuse before layouts change.
227 bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
228 bool fuse_gelu = _options->query(Options::Algorithm::FuseGelu);
230 convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc, fuse_gelu);
233 /* TRANSFORM DECLARATION BEGIN */
234 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
236 // Following passes are needed every time other passes create new nodes or modify existing ones.
237 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
238 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
// Custom Op resolution
240 if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
242 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
244 if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
246 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
248 if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
250 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
252 if (_options->query(Options::Algorithm::FuseMeanWithMean))
254 phase.emplace_back(std::make_unique<FuseMeanWithMeanPass>());
256 if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
258 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
260 if (_options->query(Options::Algorithm::ResolveCustomOpSplitV))
262 phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
// Fusion passes
264 if (_options->query(Options::Algorithm::FuseInstanceNorm))
266 phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
268 if (_options->query(Options::Algorithm::FuseBatchNormWithConv))
270 phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>());
272 if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv))
274 phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>());
276 if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
278 phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
280 if (_options->query(Options::Algorithm::FuseAddWithFullyConnected))
282 phase.emplace_back(std::make_unique<FuseAddWithFullyConnectedPass>());
284 if (_options->query(Options::Algorithm::FuseAddWithTConv))
286 phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
288 if (_options->query(Options::Algorithm::FuseActivationFunction))
290 phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
292 if (_options->query(Options::Algorithm::FusePRelu))
294 phase.emplace_back(std::make_unique<FusePReluPass>());
296 if (_options->query(Options::Algorithm::FuseGelu))
298 phase.emplace_back(std::make_unique<FuseGeluPass>());
300 if (_options->query(Options::Algorithm::FuseTransposeWithMean))
302 phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
// Constant folding passes
304 if (_options->query(Options::Algorithm::FoldAddV2))
306 phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
308 if (_options->query(Options::Algorithm::FoldCast))
310 phase.emplace_back(std::make_unique<luci::FoldCastPass>());
312 if (_options->query(Options::Algorithm::FoldDensify))
314 phase.emplace_back(std::make_unique<luci::FoldDensifyPass>());
316 if (_options->query(Options::Algorithm::FoldDepthwiseConv2D))
318 phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>());
320 if (_options->query(Options::Algorithm::FoldDequantize))
322 phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
324 if (_options->query(Options::Algorithm::FoldFullyConnected))
326 phase.emplace_back(std::make_unique<luci::FoldFullyConnectedPass>());
328 if (_options->query(Options::Algorithm::FoldGather))
330 phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
332 if (_options->query(Options::Algorithm::FoldSparseToDense))
334 phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
336 if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
338 phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
340 if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
342 phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
344 if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
346 phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
348 if (_options->query(Options::Algorithm::ExpandBroadcastConst))
350 phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>());
// Removal passes
352 if (_options->query(Options::Algorithm::RemoveDuplicateConst))
354 phase.emplace_back(std::make_unique<luci::RemoveDuplicateConstPass>());
356 if (_options->query(Options::Algorithm::RemoveFakeQuant))
358 phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
360 if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
362 phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
// The RemoveUnnecessaryReshape option enables two passes: single Reshape and Reshape-net removal.
364 if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
366 phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
367 phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapeNetPass>());
369 if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
371 phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>());
373 if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice))
375 phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>());
377 if (_options->query(Options::Algorithm::RemoveUnnecessarySplit))
379 phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>());
381 if (_options->query(Options::Algorithm::RemoveRedundantReshape))
383 phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
385 if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
387 phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
389 if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
391 phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
// Replacement / substitution passes
393 if (_options->query(Options::Algorithm::ReplaceNonConstFCWithBatchMatMul))
395 phase.emplace_back(std::make_unique<luci::ReplaceNonConstFCWithBatchMatMulPass>());
397 if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
399 phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
401 if (_options->query(Options::Algorithm::ReplaceSubWithAdd))
403 phase.emplace_back(std::make_unique<luci::ReplaceSubWithAddPass>());
405 if (_options->query(Options::Algorithm::SubstitutePackToReshape))
407 phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
409 if (_options->query(Options::Algorithm::SubstitutePadV2ToPad))
411 phase.emplace_back(std::make_unique<luci::SubstitutePadV2ToPadPass>());
413 if (_options->query(Options::Algorithm::SubstituteSplitVToSplit))
415 phase.emplace_back(std::make_unique<luci::SubstituteSplitVToSplitPass>());
417 if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
419 phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
421 if (_options->query(Options::Algorithm::SubstituteStridedSliceToReshape))
423 phase.emplace_back(std::make_unique<luci::SubstituteStridedSliceToReshapePass>());
425 if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
427 phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
// Transformation / decomposition passes
429 if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass))
431 phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
433 if (_options->query(Options::Algorithm::TransformMinReluToRelu6Pass))
435 phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
437 if (_options->query(Options::Algorithm::DecomposeHardSwishPass))
439 phase.emplace_back(std::make_unique<luci::DecomposeHardSwishPass>());
441 if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
443 phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
445 // Forward Reshape/Transpose is done after
446 // 1. SubstituteXXXToReshape
447 // 2. RemoveRedundantReshape/Transpose
448 // See https://github.com/Samsung/ONE/pull/10596 for more details
449 if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
451 phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
453 if (_options->query(Options::Algorithm::ForwardTransposeOp))
455 phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
458 /* TRANSFORM DECLARATION END */
460 ProgressReporter prog(g, logo::PhaseStrategy::Restart);
461 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
462 phase_runner.attach(&prog);
463 phase_runner.run(phase);
466 void CircleOptimizer::sparsify(loco::Graph *g) const
468 if (_options->query(Options::Algorithm::SparsifyTensorPass))
470 std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
471 std::string str_tarversal_order =
472 _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
473 std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
474 std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
475 std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
478 std::vector<int32_t> traversal_order = pepper::csv_to_vector<int32_t>(str_tarversal_order);
480 std::vector<DimensionType> format;
481 std::istringstream is(str_format);
482 for (char c; is >> c;)
486 format.push_back(DimensionType::DENSE);
488 format.push_back(DimensionType::SPARSE_CSR);
489 if (is.peek() == ',')
493 std::vector<int32_t> block_size = pepper::csv_to_vector<int32_t>(str_block_size);
495 std::vector<int32_t> block_map = pepper::csv_to_vector<int32_t>(str_block_map);
497 luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,