2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/CircleOptimizer.h"
19 #include "luci/Pass/FuseBCQPass.h"
20 #include "luci/Pass/FuseInstanceNormPass.h"
21 #include "luci/Pass/ResolveCustomOpAddPass.h"
22 #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
23 #include "luci/Pass/ResolveCustomOpMatMulPass.h"
24 #include "luci/Pass/QuantizeWithMinMaxPass.h"
25 #include "luci/Pass/QuantizeDequantizeWeightsPass.h"
26 // TODO add more passes
28 #include "luci/Pass/ShapeInferencePass.h"
29 #include "luci/Pass/TypeInferencePass.h"
32 #include <logo/RemoveDeadNodeWithQueryPass.h>
34 #include "ProgressReporter.h"
35 #include "CircleOptimizerUtils.h"
37 #include <logo/Phase.h>
46 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
49 void enable(Algorithm) final;
50 void param(AlgorithmParameters, const std::string &) final;
51 const std::string param(AlgorithmParameters) const final;
52 bool query(Algorithm) final;
55 std::vector<Algorithm> _algorithms;
56 std::map<AlgorithmParameters, const std::string> _algorithm_params;
59 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
61 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
63 _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
66 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
68 auto param_str = _algorithm_params.find(param);
69 if (param_str != _algorithm_params.end())
71 return param_str->second;
79 bool OptimizeOptionsImpl::query(Algorithm algo)
81 std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
82 if (it == _algorithms.end())
93 CircleOptimizer::Options *CircleOptimizer::options(void)
95 if (_options == nullptr)
97 _options = std::make_unique<OptimizeOptionsImpl>();
100 return _options.get();
103 void CircleOptimizer::optimize(loco::Graph *g) const
107 /* TRANSFORM DECLARATION BEGIN */
108 if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
110 phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
112 if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
114 phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
116 if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
118 phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
120 if (_options->query(Options::Algorithm::FuseInstanceNorm))
122 phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
124 if (_options->query(Options::Algorithm::FuseBCQ))
126 phase.emplace_back(std::make_unique<FuseBCQPass>());
129 // Shape inference is needed for added nodes doing above transformations
130 phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
131 phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
132 phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
133 /* TRANSFORM DECLARATION END */
135 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
136 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
137 phase_runner.attach(&prog);
138 phase_runner.run(phase);
141 void CircleOptimizer::quantize(loco::Graph *g) const
143 // Fake quantization of weights
144 if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
146 static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
147 static const std::vector<std::string> fakeq_supported_output_dtype{"uint8"};
148 static const std::vector<std::string> fakeq_supported_granularity{"layer"};
150 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
151 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
152 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
154 if (!in_array(to_lower_case(input_dtype), fakeq_supported_input_dtype))
155 throw std::runtime_error("Unsupported input type. List of supported input type: " +
156 to_string(fakeq_supported_input_dtype));
158 if (!in_array(to_lower_case(output_dtype), fakeq_supported_output_dtype))
159 throw std::runtime_error("Unsupported output type. List of supported output type: " +
160 to_string(fakeq_supported_output_dtype));
162 if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
163 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
164 to_string(fakeq_supported_granularity));
166 luci::QuantizeDequantizeWeightsPass fake_quantizer(
167 str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
168 fake_quantizer.run(g);
171 // Actual quantization of weights, bias, and activation
172 if (_options->query(Options::Algorithm::QuantizeWithMinMax))
174 static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
175 static const std::vector<std::string> qwmm_supported_output_dtype{"uint8"};
176 static const std::vector<std::string> qwmm_supported_granularity{"layer"};
178 auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
179 auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
180 auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
182 if (!in_array(to_lower_case(input_dtype), qwmm_supported_input_dtype))
183 throw std::runtime_error("Unsupported input type. List of supported input types: " +
184 to_string(qwmm_supported_input_dtype));
186 if (!in_array(to_lower_case(output_dtype), qwmm_supported_output_dtype))
187 throw std::runtime_error("Unsupported output type. List of supported output types: " +
188 to_string(qwmm_supported_output_dtype));
190 if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
191 throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
192 to_string(qwmm_supported_granularity));
194 luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
195 str_to_granularity(granularity));
201 // Do Shape/Type inference
202 phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
203 phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
205 ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
206 logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
207 phase_runner.attach(&prog);
208 phase_runner.run(phase);