8d3a80c914979e4a861e89d031ce04df65a3fc3f
[platform/core/ml/nnfw.git] / compiler / circle-quantizer / src / CircleQuantizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CircleExpContract.h"
18
19 #include <foder/FileLoader.h>
20
21 #include <luci/Importer.h>
22 #include <luci/CircleOptimizer.h>
23 #include <luci/Service/Validate.h>
24 #include <luci/CircleExporter.h>
25
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28 #include <vconone/vconone.h>
29
30 #include <functional>
31 #include <iostream>
32 #include <map>
33 #include <string>
34
35 using OptionHook = std::function<int(const char **)>;
36
37 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
38 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
39
40 void print_version(void)
41 {
42   std::cout << "circle-quantizer version " << vconone::get_string() << std::endl;
43   std::cout << vconone::get_copyright() << std::endl;
44 }
45
46 int entry(int argc, char **argv)
47 {
48   // Simple argument parser (based on map)
49   std::map<std::string, OptionHook> argparse;
50   luci::CircleOptimizer optimizer;
51
52   auto options = optimizer.options();
53
54   const std::string qdqw = "--quantize_dequantize_weights";
55   const std::string qwmm = "--quantize_with_minmax";
56
57   arser::Arser arser("circle-quantizer provides circle model quantization");
58
59   arser.add_argument("--version")
60       .nargs(0)
61       .required(false)
62       .default_value(false)
63       .help("Show version information and exit")
64       .exit_with(print_version);
65
66   arser.add_argument(qdqw)
67       .nargs(3)
68       .type(arser::DataType::STR_VEC)
69       .required(false)
70       .help("Quantize-dequantize weight values required action before quantization. "
71             "Three arguments required: input_dtype(float32) "
72             "output_dtype(uint8) granularity(layer, channel)");
73
74   arser.add_argument(qwmm)
75       .nargs(3)
76       .type(arser::DataType::STR_VEC)
77       .required(false)
78       .help("Quantize with min/max values. "
79             "Three arguments required: input_dtype(float32) "
80             "output_dtype(uint8) granularity(layer, channel)");
81
82   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
83   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
84
85   try
86   {
87     arser.parse(argc, argv);
88   }
89   catch (const std::runtime_error &err)
90   {
91     std::cout << err.what() << std::endl;
92     std::cout << arser;
93     return 255;
94   }
95
96   if (arser[qdqw])
97   {
98     auto values = arser.get<std::vector<std::string>>(qdqw);
99     if (values.size() != 3)
100     {
101       std::cerr << arser;
102       return 255;
103     }
104     options->enable(Algorithms::QuantizeDequantizeWeights);
105
106     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
107     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
108     options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
109   }
110
111   if (arser[qwmm])
112   {
113     auto values = arser.get<std::vector<std::string>>(qwmm);
114     if (values.size() != 3)
115     {
116       std::cerr << arser;
117       return 255;
118     }
119     options->enable(Algorithms::QuantizeWithMinMax);
120
121     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
122     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
123     options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
124   }
125
126   std::string input_path = arser.get<std::string>("input");
127   std::string output_path = arser.get<std::string>("output");
128
129   // Load model from the file
130   foder::FileLoader file_loader{input_path};
131   std::vector<char> model_data = file_loader.load();
132   const circle::Model *circle_model = circle::GetModel(model_data.data());
133   if (circle_model == nullptr)
134   {
135     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
136     return EXIT_FAILURE;
137   }
138
139   // Import from input Circle file
140   luci::Importer importer;
141   auto module = importer.importModule(circle_model);
142
143   for (size_t idx = 0; idx < module->size(); ++idx)
144   {
145     auto graph = module->graph(idx);
146
147     // quantize the graph
148     optimizer.quantize(graph);
149
150     if (!luci::validate(graph))
151     {
152       std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
153       return 255;
154     }
155   }
156
157   // Export to output Circle file
158   luci::CircleExporter exporter;
159
160   CircleExpContract contract(module.get(), output_path);
161
162   if (!exporter.invoke(&contract))
163   {
164     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
165     return 255;
166   }
167
168   return 0;
169 }