2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/CircleOptimizer.h"
19 #include "luci/Pass/FuseBatchNormWithTConv.h"
20 #include "luci/Pass/FuseBCQPass.h"
21 #include "luci/Pass/FuseInstanceNormPass.h"
22 #include "luci/Pass/ResolveCustomOpAddPass.h"
23 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
24 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
25 #include "luci/Pass/RequantizePass.h"
26 #include "luci/Pass/QuantizeWithMinMaxPass.h"
27 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
28 // TODO add more passes
30 #include "luci/Pass/ShapeInferencePass.h"
31 #include "luci/Pass/TypeInferencePass.h"
34 #include <logo/RemoveDeadNodeWithQueryPass.h>
36 #include "ProgressReporter.h"
37 #include "CircleOptimizerUtils.h"
39 #include <luci/IR/CircleNodes.h>
40 #include <logo/Phase.h>
49 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
52 void enable(Algorithm) final;
53 void param(AlgorithmParameters, const std::string &) final;
54 const std::string param(AlgorithmParameters) const final;
55 bool query(Algorithm) final;
58 std::vector<Algorithm> _algorithms;
59 std::map<AlgorithmParameters, const std::string> _algorithm_params;
62 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
64 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
66 _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
69 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
71 auto param_str = _algorithm_params.find(param);
72 if (param_str != _algorithm_params.end())
74 return param_str->second;
82 bool OptimizeOptionsImpl::query(Algorithm algo)
84 std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
85 if (it == _algorithms.end())
96 CircleOptimizer::Options *CircleOptimizer::options(void)
98 if (_options == nullptr)
100 _options = std::make_unique<OptimizeOptionsImpl>();
103 return _options.get();
106 void CircleOptimizer::optimize(loco::Graph *g) const
110 /* TRANSFORM DECLARATION BEGIN */
111 if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
113 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
115 if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
117 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
119 if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
121 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
123 if (_options->query(Options::Algorithm::FuseInstanceNorm))
125 phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
127 if (_options->query(Options::Algorithm::FuseBCQ))
129 phase.emplace_back(std::make_unique<FuseBCQPass>());
131 if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
133 phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
136 // Shape inference is needed for added nodes doing above transformations
137 phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
138 phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
139 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
140 /* TRANSFORM DECLARATION END */
142 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
143 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
144 phase_runner.attach(&prog);
145 phase_runner.run(phase);
148 void CircleOptimizer::quantize(loco::Graph *g) const
150 // Fake quantization of weights
151 if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
153 static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
154 static const std::vector<std::string> fakeq_supported_output_dtype{"uint8"};
155 static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
157 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
158 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
159 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
161 if (!in_array(to_lower_case(input_dtype), fakeq_supported_input_dtype))
162 throw std::runtime_error("Unsupported input type. List of supported input type: " +
163 to_string(fakeq_supported_input_dtype));
165 if (!in_array(to_lower_case(output_dtype), fakeq_supported_output_dtype))
166 throw std::runtime_error("Unsupported output type. List of supported output type: " +
167 to_string(fakeq_supported_output_dtype));
169 if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
170 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
171 to_string(fakeq_supported_granularity));
173 // Clear existing quantparams before doing fake quantization
174 for (auto node : loco::active_nodes(loco::output_nodes(g)))
176 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
177 if (circle_node->quantparam() != nullptr)
178 circle_node->quantparam(nullptr);
181 luci::QuantizeDequantizeWeightsPass fake_quantizer(
182 str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
183 fake_quantizer.run(g);
186 // Actual quantization of weights, bias, and activation
187 if (_options->query(Options::Algorithm::QuantizeWithMinMax))
189 static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
190 static const std::vector<std::string> qwmm_supported_output_dtype{"uint8"};
191 static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
193 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
194 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
195 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
197 if (!in_array(to_lower_case(input_dtype), qwmm_supported_input_dtype))
198 throw std::runtime_error("Unsupported input type. List of supported input types: " +
199 to_string(qwmm_supported_input_dtype));
201 if (!in_array(to_lower_case(output_dtype), qwmm_supported_output_dtype))
202 throw std::runtime_error("Unsupported output type. List of supported output types: " +
203 to_string(qwmm_supported_output_dtype));
205 if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
206 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
207 to_string(qwmm_supported_granularity));
209 luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
210 str_to_granularity(granularity));
215 if (_options->query(Options::Algorithm::Requantize))
217 static const std::vector<std::string> rq_supported_input_dtype{"int8"};
218 static const std::vector<std::string> rq_supported_output_dtype{"uint8"};
220 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
221 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
223 if (!in_array(to_lower_case(input_dtype), rq_supported_input_dtype))
224 throw std::runtime_error("Unsupported input type. List of supported input types: " +
225 to_string(rq_supported_input_dtype));
227 if (!in_array(to_lower_case(output_dtype), rq_supported_output_dtype))
228 throw std::runtime_error("Unsupported output type. List of supported output types: " +
229 to_string(rq_supported_output_dtype));
231 luci::RequantizePass requantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype));
237 // Do Shape/Type inference
238 phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
239 phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
241 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
242 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
243 phase_runner.attach(&prog);
244 phase_runner.run(phase);