Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / circle-quantizer / src / CircleQuantizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <foder/FileLoader.h>
18
19 #include <luci/Importer.h>
20 #include <luci/CircleOptimizer.h>
21 #include <luci/Service/Validate.h>
22 #include <luci/CircleExporter.h>
23 #include <luci/CircleFileExpContract.h>
24
25 #include <oops/InternalExn.h>
26 #include <arser/arser.h>
27 #include <vconone/vconone.h>
28
29 #include <functional>
30 #include <iostream>
31 #include <map>
32 #include <string>
33
34 using OptionHook = std::function<int(const char **)>;
35
36 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
37 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
38
39 void print_exclusive_options(void)
40 {
41   std::cout << "Use only one of the 3 options below." << std::endl;
42   std::cout << "    --quantize_dequantize_weights" << std::endl;
43   std::cout << "    --quantize_with_minmax" << std::endl;
44   std::cout << "    --requantize" << std::endl;
45 }
46
47 void print_version(void)
48 {
49   std::cout << "circle-quantizer version " << vconone::get_string() << std::endl;
50   std::cout << vconone::get_copyright() << std::endl;
51 }
52
53 int entry(int argc, char **argv)
54 {
55   // Simple argument parser (based on map)
56   std::map<std::string, OptionHook> argparse;
57   luci::CircleOptimizer optimizer;
58
59   auto options = optimizer.options();
60
61   const std::string qdqw = "--quantize_dequantize_weights";
62   const std::string qwmm = "--quantize_with_minmax";
63   const std::string rq = "--requantize";
64
65   arser::Arser arser("circle-quantizer provides circle model quantization");
66
67   arser.add_argument("--version")
68       .nargs(0)
69       .required(false)
70       .default_value(false)
71       .help("Show version information and exit")
72       .exit_with(print_version);
73
74   arser.add_argument(qdqw)
75       .nargs(3)
76       .type(arser::DataType::STR_VEC)
77       .required(false)
78       .help("Quantize-dequantize weight values required action before quantization. "
79             "Three arguments required: input_dtype(float32) "
80             "output_dtype(uint8) granularity(layer, channel)");
81
82   arser.add_argument(qwmm)
83       .nargs(3)
84       .type(arser::DataType::STR_VEC)
85       .required(false)
86       .help("Quantize with min/max values. "
87             "Three arguments required: input_dtype(float32) "
88             "output_dtype(uint8) granularity(layer, channel)");
89
90   arser.add_argument(rq)
91       .nargs(2)
92       .type(arser::DataType::STR_VEC)
93       .required(false)
94       .help("Requantize a quantized model. "
95             "Two arguments required: input_dtype(int8) "
96             "output_dtype(uint8)");
97
98   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
99   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
100
101   try
102   {
103     arser.parse(argc, argv);
104   }
105   catch (const std::runtime_error &err)
106   {
107     std::cout << err.what() << std::endl;
108     std::cout << arser;
109     return 255;
110   }
111
112   if (arser[qdqw])
113   {
114     if (arser[qwmm] || arser[rq])
115     {
116       print_exclusive_options();
117       return 255;
118     }
119     auto values = arser.get<std::vector<std::string>>(qdqw);
120     if (values.size() != 3)
121     {
122       std::cerr << arser;
123       return 255;
124     }
125     options->enable(Algorithms::QuantizeDequantizeWeights);
126
127     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
128     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
129     options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
130   }
131
132   if (arser[qwmm])
133   {
134     if (arser[qdqw] || arser[rq])
135     {
136       print_exclusive_options();
137       return 255;
138     }
139     auto values = arser.get<std::vector<std::string>>(qwmm);
140     if (values.size() != 3)
141     {
142       std::cerr << arser;
143       return 255;
144     }
145     options->enable(Algorithms::QuantizeWithMinMax);
146
147     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
148     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
149     options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
150   }
151
152   if (arser[rq])
153   {
154     if (arser[qwmm] || arser[qdqw])
155     {
156       print_exclusive_options();
157       return 255;
158     }
159     auto values = arser.get<std::vector<std::string>>(rq);
160     if (values.size() != 2)
161     {
162       std::cerr << arser;
163       return 255;
164     }
165     options->enable(Algorithms::Requantize);
166
167     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
168     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
169   }
170
171   std::string input_path = arser.get<std::string>("input");
172   std::string output_path = arser.get<std::string>("output");
173
174   // Load model from the file
175   foder::FileLoader file_loader{input_path};
176   std::vector<char> model_data = file_loader.load();
177
178   // Verify flatbuffers
179   flatbuffers::Verifier verifier{reinterpret_cast<uint8_t *>(model_data.data()), model_data.size()};
180   if (!circle::VerifyModelBuffer(verifier))
181   {
182     std::cerr << "ERROR: Invalid input file '" << input_path << "'" << std::endl;
183     return EXIT_FAILURE;
184   }
185
186   const circle::Model *circle_model = circle::GetModel(model_data.data());
187   if (circle_model == nullptr)
188   {
189     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
190     return EXIT_FAILURE;
191   }
192
193   // Import from input Circle file
194   luci::Importer importer;
195   auto module = importer.importModule(circle_model);
196
197   for (size_t idx = 0; idx < module->size(); ++idx)
198   {
199     auto graph = module->graph(idx);
200
201     // quantize the graph
202     optimizer.quantize(graph);
203
204     if (!luci::validate(graph))
205     {
206       std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
207       return 255;
208     }
209   }
210
211   // Export to output Circle file
212   luci::CircleExporter exporter;
213
214   luci::CircleFileExpContract contract(module.get(), output_path);
215
216   if (!exporter.invoke(&contract))
217   {
218     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
219     return 255;
220   }
221
222   return 0;
223 }