2edf7a9c63c2e9351a14de2fb3e2217718b0265d
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / CircleOptimizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "luci/CircleOptimizer.h"

#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
// TODO add more passes

#include "luci/Pass/ShapeInferencePass.h"
#include "luci/Pass/TypeInferencePass.h"

// logo passes
#include <logo/RemoveDeadNodeWithQueryPass.h>

#include "ProgressReporter.h"
#include "CircleOptimizerUtils.h"

#include <logo/Phase.h>

#include <algorithm>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
40
41 namespace
42 {
43
44 using namespace luci;
45
46 class OptimizeOptionsImpl final : public luci::CircleOptimizer::Options
47 {
48 public:
49   void enable(Algorithm) final;
50   void param(AlgorithmParameters, const std::string &) final;
51   const std::string param(AlgorithmParameters) const final;
52   bool query(Algorithm) final;
53
54 private:
55   std::vector<Algorithm> _algorithms;
56   std::map<AlgorithmParameters, const std::string> _algorithm_params;
57 };
58
59 void OptimizeOptionsImpl::enable(Algorithm algo) { _algorithms.push_back(algo); }
60
61 void OptimizeOptionsImpl::param(AlgorithmParameters param, const std::string &str)
62 {
63   _algorithm_params.insert(std::pair<AlgorithmParameters, const std::string>(param, str));
64 }
65
66 const std::string OptimizeOptionsImpl::param(AlgorithmParameters param) const
67 {
68   auto param_str = _algorithm_params.find(param);
69   if (param_str != _algorithm_params.end())
70   {
71     return param_str->second;
72   }
73   else
74   {
75     return std::string();
76   }
77 }
78
79 bool OptimizeOptionsImpl::query(Algorithm algo)
80 {
81   std::vector<Algorithm>::iterator it = std::find(_algorithms.begin(), _algorithms.end(), algo);
82   if (it == _algorithms.end())
83     return false;
84
85   return true;
86 }
87
88 } // namespace
89
90 namespace luci
91 {
92
93 CircleOptimizer::Options *CircleOptimizer::options(void)
94 {
95   if (_options == nullptr)
96   {
97     _options = std::make_unique<OptimizeOptionsImpl>();
98   }
99
100   return _options.get();
101 }
102
103 void CircleOptimizer::optimize(loco::Graph *g) const
104 {
105   logo::Phase phase;
106
107   /* TRANSFORM DECLARATION BEGIN */
108   if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
109   {
110     phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
111   }
112   if (_options->query(Options::Algorithm::ResolveCustomOpBatchMatMul))
113   {
114     phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
115   }
116   if (_options->query(Options::Algorithm::ResolveCustomOpMatMul))
117   {
118     phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
119   }
120   if (_options->query(Options::Algorithm::FuseInstanceNorm))
121   {
122     phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
123   }
124   if (_options->query(Options::Algorithm::FuseBCQ))
125   {
126     phase.emplace_back(std::make_unique<FuseBCQPass>());
127   }
128
129   // Shape inference is needed for added nodes doing above transformations
130   phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
131   phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
132   phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
133   /* TRANSFORM DECLARATION END */
134
135   ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
136   logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
137   phase_runner.attach(&prog);
138   phase_runner.run(phase);
139 }
140
141 void CircleOptimizer::quantize(loco::Graph *g) const
142 {
143   // Fake quantization of weights
144   if (_options->query(Options::Algorithm::QuantizeDequantizeWeights))
145   {
146     static const std::vector<std::string> fakeq_supported_input_dtype{"float32"};
147     static const std::vector<std::string> fakeq_supported_output_dtype{"uint8"};
148     static const std::vector<std::string> fakeq_supported_granularity{"layer", "channel"};
149
150     auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
151     auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
152     auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
153
154     if (!in_array(to_lower_case(input_dtype), fakeq_supported_input_dtype))
155       throw std::runtime_error("Unsupported input type. List of supported input type: " +
156                                to_string(fakeq_supported_input_dtype));
157
158     if (!in_array(to_lower_case(output_dtype), fakeq_supported_output_dtype))
159       throw std::runtime_error("Unsupported output type. List of supported output type: " +
160                                to_string(fakeq_supported_output_dtype));
161
162     if (!in_array(to_lower_case(granularity), fakeq_supported_granularity))
163       throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
164                                to_string(fakeq_supported_granularity));
165
166     luci::QuantizeDequantizeWeightsPass fake_quantizer(
167         str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
168     fake_quantizer.run(g);
169   }
170
171   // Actual quantization of weights, bias, and activation
172   if (_options->query(Options::Algorithm::QuantizeWithMinMax))
173   {
174     static const std::vector<std::string> qwmm_supported_input_dtype{"float32"};
175     static const std::vector<std::string> qwmm_supported_output_dtype{"uint8"};
176     static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
177
178     auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
179     auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
180     auto granularity = _options->param(Options::AlgorithmParameters::Quantize_granularity);
181
182     if (!in_array(to_lower_case(input_dtype), qwmm_supported_input_dtype))
183       throw std::runtime_error("Unsupported input type. List of supported input types: " +
184                                to_string(qwmm_supported_input_dtype));
185
186     if (!in_array(to_lower_case(output_dtype), qwmm_supported_output_dtype))
187       throw std::runtime_error("Unsupported output type. List of supported output types: " +
188                                to_string(qwmm_supported_output_dtype));
189
190     if (!in_array(to_lower_case(granularity), qwmm_supported_granularity))
191       throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
192                                to_string(qwmm_supported_granularity));
193
194     luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
195                                            str_to_granularity(granularity));
196     quantizer.run(g);
197   }
198
199   logo::Phase phase;
200
201   // Do Shape/Type inference
202   phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
203   phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
204
205   ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
206   logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
207   phase_runner.attach(&prog);
208   phase_runner.run(phase);
209 }
210
211 } // namespace luci