2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/CircleOptimizer.h"
19 #include "luci/Pass/ConvertNCHWToNHWCPass.h"
20 #include "luci/Pass/ExpandBroadcastConstPass.h"
21 #include "luci/Pass/FoldAddV2Pass.h"
22 #include "luci/Pass/FoldCastPass.h"
23 #include "luci/Pass/FoldDensifyPass.h"
24 #include "luci/Pass/FoldDepthwiseConv2DPass.h"
25 #include "luci/Pass/FoldDequantizePass.h"
26 #include "luci/Pass/FoldFullyConnectedPass.h"
27 #include "luci/Pass/FoldGatherPass.h"
28 #include "luci/Pass/FoldSparseToDensePass.h"
29 #include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
30 #include "luci/Pass/ForwardTransposeOpPass.h"
31 #include "luci/Pass/FuseActivationFunctionPass.h"
32 #include "luci/Pass/FuseAddWithFullyConnectedPass.h"
33 #include "luci/Pass/FuseAddWithTConvPass.h"
34 #include "luci/Pass/FuseBatchNormWithConvPass.h"
35 #include "luci/Pass/FuseBatchNormWithDwConvPass.h"
36 #include "luci/Pass/FuseBatchNormWithTConvPass.h"
37 #include "luci/Pass/FuseBCQPass.h"
38 #include "luci/Pass/FuseInstanceNormPass.h"
39 #include "luci/Pass/FuseMeanWithMeanPass.h"
40 #include "luci/Pass/FusePreActivationBatchNormPass.h"
41 #include "luci/Pass/FusePReluPass.h"
42 #include "luci/Pass/FuseTransposeWithMeanPass.h"
43 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
44 #include "luci/Pass/RemoveDuplicateConstPass.h"
45 #include "luci/Pass/RemoveFakeQuantPass.h"
46 #include "luci/Pass/RemoveQuantDequantSeqPass.h"
47 #include "luci/Pass/RemoveRedundantReshapePass.h"
48 #include "luci/Pass/RemoveRedundantTransposePass.h"
49 #include "luci/Pass/RemoveRedundantQuantizePass.h"
50 #include "luci/Pass/RemoveUnnecessaryReshapePass.h"
51 #include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
52 #include "luci/Pass/RemoveUnnecessarySlicePass.h"
53 #include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
54 #include "luci/Pass/RemoveUnnecessarySplitPass.h"
55 #include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
56 #include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
57 #include "luci/Pass/ReplaceSubWithAddPass.h"
58 #include "luci/Pass/ResolveCustomOpAddPass.h"
59 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
60 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
61 #include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
62 #include "luci/Pass/ResolveCustomOpSplitVPass.h"
63 #include "luci/Pass/SparsifyTensorPass.h"
64 #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
65 #include "luci/Pass/SubstitutePackToReshapePass.h"
66 #include "luci/Pass/SubstitutePadV2ToPadPass.h"
67 #include "luci/Pass/SubstituteSplitVToSplitPass.h"
68 #include "luci/Pass/SubstituteSqueezeToReshapePass.h"
69 #include "luci/Pass/SubstituteStridedSliceToReshapePass.h"
70 #include "luci/Pass/SubstituteTransposeToReshapePass.h"
71 #include "luci/Pass/TransformMinMaxToRelu6Pass.h"
72 #include "luci/Pass/TransformMinReluToRelu6Pass.h"
73 #include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
74 // TODO add more passes
76 #include "luci/Pass/CircleShapeInferencePass.h"
77 #include "luci/Pass/CircleTypeInferencePass.h"
80 #include <logo/RemoveDeadNodeWithQueryPass.h>
82 #include "ModulePhase.h"
83 #include "ProgressReporter.h"
85 #include <luci/IR/CircleNodes.h>
86 #include <logo/Phase.h>
87 #include <pepper/csv2vec.h>
97 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
100 void enable(Algorithm) final;
101 void param(AlgorithmParameters, const std::string &) final;
102 const std::string param(AlgorithmParameters) const final;
103 bool query(Algorithm) final;
106 std::vector<Algorithm> _algorithms;
107 std::map<AlgorithmParameters, const std::string> _algorithm_params;
110 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
112 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
114 _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
117 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
119 auto param_str = _algorithm_params.find(param);
120 if (param_str != _algorithm_params.end())
122 return param_str->second;
126 return std::string();
130 bool OptimizeOptionsImpl::query(Algorithm algo)
132 std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
133 if (it == _algorithms.end())
139 // TODO Make a struct for args
140 void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc)
144 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
145 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
146 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
148 // Resolve custom Ops
149 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
150 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
151 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
152 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
153 phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
155 // Fuse FullyConnected with Add
156 // Why we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass?
157 // FullyConnected Op's layout is not changed in ConvertNCHWToNHWCPass, while
158 // Add Op's layer is changed from NCHW to NHWC.
159 // This disables fusion of Add and FullyConnected after ConvertNCHWToNHWC.
161 phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
164 std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
166 ProgressReporter prog(g, logo::PhaseStrategy::Restart);
167 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
168 phase_runner.attach(&prog);
169 phase_runner.run(phase);
177 CircleOptimizer::Options *CircleOptimizer::options(void)
179 if (_options == nullptr)
181 _options = std::make_unique<OptimizeOptionsImpl>();
184 return _options.get();
187 void CircleOptimizer::optimize(luci::Module *m) const
191 // Following passes are needed everytime when other passes create new node or modify some nodes.
192 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
193 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
195 if (_options->query(Options::Algorithm::FuseBCQ))
197 phase.emplace_back(std::make_unique<FuseBCQPass>());
200 ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
201 PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
202 phase_runner.attach(&prog);
203 phase_runner.run(phase);
206 void CircleOptimizer::optimize(loco::Graph *g) const
210 // Conversion from NCHW to NHWC is done first to avoid interference with other optimizations.
211 if (_options->query(Options::Algorithm::ConvertNCHWToNHWC))
213 bool preserve_input =
214 _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_input_shape) != "true";
215 bool preserve_output =
216 _options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
218 bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
220 convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc);
223 /* TRANSFORM DECLARATION BEGIN */
224 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
226 // Following passes are needed everytime when other passes create new node or modify some nodes.
227 phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
228 phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
230 if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
232 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
234 if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
236 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
238 if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
240 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
242 if (_options->query(Options::Algorithm::FuseMeanWithMean))
244 phase.emplace_back(std::make_unique<FuseMeanWithMeanPass>());
246 if (_options->query(Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax))
248 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
250 if (_options->query(Options::Algorithm::ResolveCustomOpSplitV))
252 phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
254 if (_options->query(Options::Algorithm::FuseInstanceNorm))
256 phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
258 if (_options->query(Options::Algorithm::FuseBatchNormWithConv))
260 phase.emplace_back(std::make_unique<FuseBatchNormWithConvPass>());
262 if (_options->query(Options::Algorithm::FuseBatchNormWithDwConv))
264 phase.emplace_back(std::make_unique<FuseBatchNormWithDwConvPass>());
266 if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
268 phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
270 if (_options->query(Options::Algorithm::FuseAddWithFullyConnected))
272 phase.emplace_back(std::make_unique<FuseAddWithFullyConnectedPass>());
274 if (_options->query(Options::Algorithm::FuseAddWithTConv))
276 phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
278 if (_options->query(Options::Algorithm::FuseActivationFunction))
280 phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
282 if (_options->query(Options::Algorithm::FusePRelu))
284 phase.emplace_back(std::make_unique<FusePReluPass>());
286 if (_options->query(Options::Algorithm::FuseTransposeWithMean))
288 phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
290 if (_options->query(Options::Algorithm::FoldAddV2))
292 phase.emplace_back(std::make_unique<luci::FoldAddV2Pass>());
294 if (_options->query(Options::Algorithm::FoldCast))
296 phase.emplace_back(std::make_unique<luci::FoldCastPass>());
298 if (_options->query(Options::Algorithm::FoldDensify))
300 phase.emplace_back(std::make_unique<luci::FoldDensifyPass>());
302 if (_options->query(Options::Algorithm::FoldDepthwiseConv2D))
304 phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>());
306 if (_options->query(Options::Algorithm::FoldDequantize))
308 phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
310 if (_options->query(Options::Algorithm::FoldFullyConnected))
312 phase.emplace_back(std::make_unique<luci::FoldFullyConnectedPass>());
314 if (_options->query(Options::Algorithm::FoldGather))
316 phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
318 if (_options->query(Options::Algorithm::FoldSparseToDense))
320 phase.emplace_back(std::make_unique<luci::FoldSparseToDensePass>());
322 if (_options->query(Options::Algorithm::ForwardReshapeToUnaryOp))
324 phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
326 if (_options->query(Options::Algorithm::ForwardTransposeOp))
328 phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
330 if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
332 phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
334 if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
336 phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
338 if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
340 phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
342 if (_options->query(Options::Algorithm::ExpandBroadcastConst))
344 phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>());
346 if (_options->query(Options::Algorithm::RemoveDuplicateConst))
348 phase.emplace_back(std::make_unique<luci::RemoveDuplicateConstPass>());
350 if (_options->query(Options::Algorithm::RemoveFakeQuant))
352 phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
354 if (_options->query(Options::Algorithm::RemoveQuantDequantSeq))
356 phase.emplace_back(std::make_unique<luci::RemoveQuantDequantSeqPass>());
358 if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
360 phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
361 phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapeNetPass>());
363 if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
365 phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySlicePass>());
367 if (_options->query(Options::Algorithm::RemoveUnnecessaryStridedSlice))
369 phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryStridedSlicePass>());
371 if (_options->query(Options::Algorithm::RemoveUnnecessarySplit))
373 phase.emplace_back(std::make_unique<luci::RemoveUnnecessarySplitPass>());
375 if (_options->query(Options::Algorithm::RemoveRedundantReshape))
377 phase.emplace_back(std::make_unique<luci::RemoveRedundantReshapePass>());
379 if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
381 phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
383 if (_options->query(Options::Algorithm::RemoveRedundantQuantize))
385 phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
387 if (_options->query(Options::Algorithm::ReplaceNonConstFCWithBatchMatMul))
389 phase.emplace_back(std::make_unique<luci::ReplaceNonConstFCWithBatchMatMulPass>());
391 if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
393 phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
395 if (_options->query(Options::Algorithm::ReplaceSubWithAdd))
397 phase.emplace_back(std::make_unique<luci::ReplaceSubWithAddPass>());
399 if (_options->query(Options::Algorithm::SubstitutePackToReshape))
401 phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
403 if (_options->query(Options::Algorithm::SubstitutePadV2ToPad))
405 phase.emplace_back(std::make_unique<luci::SubstitutePadV2ToPadPass>());
407 if (_options->query(Options::Algorithm::SubstituteSplitVToSplit))
409 phase.emplace_back(std::make_unique<luci::SubstituteSplitVToSplitPass>());
411 if (_options->query(Options::Algorithm::SubstituteSqueezeToReshape))
413 phase.emplace_back(std::make_unique<luci::SubstituteSqueezeToReshapePass>());
415 if (_options->query(Options::Algorithm::SubstituteStridedSliceToReshape))
417 phase.emplace_back(std::make_unique<luci::SubstituteStridedSliceToReshapePass>());
419 if (_options->query(Options::Algorithm::SubstituteTransposeToReshape))
421 phase.emplace_back(std::make_unique<luci::SubstituteTransposeToReshapePass>());
423 if (_options->query(Options::Algorithm::TransformMinMaxToRelu6Pass))
425 phase.emplace_back(std::make_unique<luci::TransformMinMaxToRelu6Pass>());
427 if (_options->query(Options::Algorithm::TransformMinReluToRelu6Pass))
429 phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
431 if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
433 phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
436 /* TRANSFORM DECLARATION END */
438 ProgressReporter prog(g, logo::PhaseStrategy::Restart);
439 logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
440 phase_runner.attach(&prog);
441 phase_runner.run(phase);
444 void CircleOptimizer::sparsify(loco::Graph *g) const
446 if (_options->query(Options::Algorithm::SparsifyTensorPass))
448 std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
449 std::string str_tarversal_order =
450 _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
451 std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
452 std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
453 std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
456 std::vector<int32_t> traversal_order = pepper::csv_to_vector<int32_t>(str_tarversal_order);
458 std::vector<DimensionType> format;
459 std::istringstream is(str_format);
460 for (char c; is >> c;)
464 format.push_back(DimensionType::DENSE);
466 format.push_back(DimensionType::SPARSE_CSR);
467 if (is.peek() == ',')
471 std::vector<int32_t> block_size = pepper::csv_to_vector<int32_t>(str_block_size);
473 std::vector<int32_t> block_map = pepper::csv_to_vector<int32_t>(str_block_map);
475 luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,