2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/CircleOptimizer.h"
19 #include "luci/Pass/FoldDequantizePass.h"
20 #include "luci/Pass/FuseActivationFunctionPass.h"
21 #include "luci/Pass/FuseAddWithTConvPass.h"
22 #include "luci/Pass/FuseBatchNormWithTConv.h"
23 #include "luci/Pass/FuseBCQPass.h"
24 #include "luci/Pass/FuseInstanceNormPass.h"
25 #include "luci/Pass/FusePreActivationBatchNormPass.h"
26 #include "luci/Pass/MakeBatchNormGammaPositivePass.h"
27 #include "luci/Pass/ResolveCustomOpAddPass.h"
28 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
29 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
30 #include "luci/Pass/RequantizePass.h"
31 #include "luci/Pass/QuantizeWithMinMaxPass.h"
32 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
33 #include "luci/Pass/SparsifyTensorPass.h"
34 // TODO add more passes
36 #include "luci/Pass/ShapeInferencePass.h"
37 #include "luci/Pass/TypeInferencePass.h"
40 #include <logo/RemoveDeadNodeWithQueryPass.h>
42 #include "ProgressReporter.h"
43 #include "CircleOptimizerUtils.h"
45 #include <luci/IR/CircleNodes.h>
46 #include <logo/Phase.h>
/**
 * @brief Parse a comma-delimited string of non-negative integers, e.g. "0,2,1".
 *
 * Values are read with stream extraction (`is >> i`), so parsing stops at the
 * first token that is not an unsigned integer; a single ',' separator after
 * each value is skipped. An empty string yields an empty vector.
 *
 * @param str comma-delimited integer list
 * @return parsed integers, in input order
 */
std::vector<int> parseIntFromCommadelimitedStr(std::string str)
{
  std::vector<int> ret;
  std::istringstream is(str);
  for (uint32_t i; is >> i;)
  {
    ret.push_back(i);
    // Skip the delimiter so the next extraction starts at the next number
    if (is.peek() == ',')
      is.ignore();
  }
  return ret;
}
70 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
73 void enable(Algorithm) final;
74 void param(AlgorithmParameters, const std::string &) final;
75 const std::string param(AlgorithmParameters) const final;
76 bool query(Algorithm) final;
79 std::vector<Algorithm> _algorithms;
80 std::map<AlgorithmParameters, const std::string> _algorithm_params;
83 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
85 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
87 _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
90 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
92 auto param_str = _algorithm_params.find(param);
93 if (param_str != _algorithm_params.end())
95 return param_str->second;
103 bool OptimizeOptionsImpl::query(Algorithm algo)
105 std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
106 if (it == _algorithms.end())
117 CircleOptimizer::Options *CircleOptimizer::options(void)
119 if (_options == nullptr)
121 _options = std::make_unique<OptimizeOptionsImpl>();
124 return _options.get();
127 void CircleOptimizer::optimize(loco::Graph *g) const
131 /* TRANSFORM DECLARATION BEGIN */
132 if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
134 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
136 if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
138 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
140 if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
142 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
144 if (_options->query(Options::Algorithm::FuseInstanceNorm))
146 phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
148 if (_options->query(Options::Algorithm::FuseBCQ))
150 phase.emplace_back(std::make_unique<FuseBCQPass>());
152 if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
154 phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
156 if (_options->query(Options::Algorithm::FuseAddWithTConv))
158 phase.emplace_back(std::make_unique<FuseAddWithTConvPass>());
160 if (_options->query(Options::Algorithm::FuseActivationFunction))
162 phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
164 if (_options->query(Options::Algorithm::FoldDequantize))
166 phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
168 if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
170 phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
172 if (_options->query(Options::Algorithm::MakeBatchNormGammaPositive))
174 phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
177 // Shape inference is needed for added nodes doing above transformations
178 phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
179 phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
180 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
181 /* TRANSFORM DECLARATION END */
183 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
184 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
185 phase_runner.attach(&prog);
186 phase_runner.run(phase);
189 void CircleOptimizer::quantize(loco::Graph *g) const
191 // Fake quantization of weights
192 if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
194 static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
195 static const std::vector<std::string> fakeq_supported_output_dtype{"uint8", "int16"};
196 static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
198 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
199 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
200 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
202 if (!in_array(to_lower_case(input_dtype), fakeq_supported_input_dtype))
203 throw std::runtime_error("Unsupported input type. List of supported input type: " +
204 to_string(fakeq_supported_input_dtype));
206 if (!in_array(to_lower_case(output_dtype), fakeq_supported_output_dtype))
207 throw std::runtime_error("Unsupported output type. List of supported output type: " +
208 to_string(fakeq_supported_output_dtype));
210 if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
211 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
212 to_string(fakeq_supported_granularity));
214 if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
215 str_to_dtype(output_dtype) != loco::DataType::U8)
216 throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
218 // Clear existing quantparams before doing fake quantization
219 for (auto node : loco::active_nodes(loco::output_nodes(g)))
221 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
222 if (circle_node->quantparam() != nullptr)
223 circle_node->quantparam(nullptr);
226 luci::QuantizeDequantizeWeightsPass fake_quantizer(
227 str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
228 fake_quantizer.run(g);
231 // Actual quantization of weights, bias, and activation
232 if (_options->query(Options::Algorithm::QuantizeWithMinMax))
234 static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
235 static const std::vector<std::string> qwmm_supported_output_dtype{"uint8", "int16"};
236 static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
238 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
239 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
240 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
242 if (!in_array(to_lower_case(input_dtype), qwmm_supported_input_dtype))
243 throw std::runtime_error("Unsupported input type. List of supported input types: " +
244 to_string(qwmm_supported_input_dtype));
246 if (!in_array(to_lower_case(output_dtype), qwmm_supported_output_dtype))
247 throw std::runtime_error("Unsupported output type. List of supported output types: " +
248 to_string(qwmm_supported_output_dtype));
250 if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
251 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
252 to_string(qwmm_supported_granularity));
254 if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
255 str_to_dtype(output_dtype) != loco::DataType::U8)
256 throw std::runtime_error("Layer-wise quantization only supports uint8 dtype.");
258 luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
259 str_to_granularity(granularity));
264 if (_options->query(Options::Algorithm::Requantize))
266 static const std::vector<std::string> rq_supported_input_dtype{"int8"};
267 static const std::vector<std::string> rq_supported_output_dtype{"uint8"};
269 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
270 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
272 if (!in_array(to_lower_case(input_dtype), rq_supported_input_dtype))
273 throw std::runtime_error("Unsupported input type. List of supported input types: " +
274 to_string(rq_supported_input_dtype));
276 if (!in_array(to_lower_case(output_dtype), rq_supported_output_dtype))
277 throw std::runtime_error("Unsupported output type. List of supported output types: " +
278 to_string(rq_supported_output_dtype));
280 luci::RequantizePass requantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype));
286 // Do Shape/Type inference
287 phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
288 phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
290 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
291 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
292 phase_runner.attach(&prog);
293 phase_runner.run(phase);
// Sparsify a named tensor in the graph using the Sparsify_* option parameters.
// NOTE(review): this definition continues past the end of this chunk — the
// SparsifyTensorPass construction below is cut off mid-initializer; the
// remaining arguments and the pass invocation are not visible here.
296 void CircleOptimizer::sparsify(loco::Graph *g) const
298 if (_options->query(Options::Algorithm::SparsifyTensorPass))
// Read all sparsification parameters as raw strings from the options.
300 std::string tensor_name = _options->param(Options::AlgorithmParameters::Sparsify_tensor_name);
301 std::string str_tarversal_order =
302 _options->param(Options::AlgorithmParameters::Sparsify_traversal_order);
303 std::string str_format = _options->param(Options::AlgorithmParameters::Sparsify_format);
304 std::string str_block_size = _options->param(Options::AlgorithmParameters::Sparsify_block_size);
305 std::string str_block_map = _options->param(Options::AlgorithmParameters::Sparsify_block_map);
// Comma-delimited integer lists are parsed with the shared helper.
308 std::vector<int32_t> traversal_order = parseIntFromCommadelimitedStr(str_tarversal_order);
// Parse the per-dimension format string character by character; one branch
// pushes DENSE and the other SPARSE_CSR (the branch conditions are in lines
// not visible in this chunk — presumably keyed on the character read; confirm
// against the full file). A ',' separator is skipped between entries.
310 std::vector<DimensionType> format;
311 std::istringstream is(str_format);
312 for (char c; is >> c;)
316 format.push_back(DimensionType::DENSE);
318 format.push_back(DimensionType::SPARSE_CSR);
319 if (is.peek() == ',')
// Block size and block map are plain comma-delimited integer lists.
323 std::vector<int32_t> block_size = parseIntFromCommadelimitedStr(str_block_size);
325 std::vector<int32_t> block_map = parseIntFromCommadelimitedStr(str_block_map);
327 luci::SparsifyTensorPass sparsifier{tensor_name, traversal_order, format, block_size,