Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / circle-quantizer / src / CircleQuantizer.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "CircleExpContract.h"
18
19 #include <foder/FileLoader.h>
20
21 #include <luci/Importer.h>
22 #include <luci/CircleOptimizer.h>
23 #include <luci/Service/Validate.h>
24 #include <luci/CircleExporter.h>
25
26 #include <oops/InternalExn.h>
27 #include <arser/arser.h>
28
29 #include <functional>
30 #include <iostream>
31 #include <map>
32 #include <string>
33
34 using OptionHook = std::function<int(const char **)>;
35
36 using Algorithms = luci::CircleOptimizer::Options::Algorithm;
37 using AlgorithmParameters = luci::CircleOptimizer::Options::AlgorithmParameters;
38
39 int entry(int argc, char **argv)
40 {
41   // Simple argument parser (based on map)
42   std::map<std::string, OptionHook> argparse;
43   luci::CircleOptimizer optimizer;
44
45   auto options = optimizer.options();
46
47   const std::string qdqw = "--quantize_dequantize_weights";
48   const std::string qwmm = "--quantize_with_minmax";
49
50   arser::Arser arser("circle-quantizer provides circle model quantization");
51
52   arser.add_argument(qdqw)
53       .nargs(3)
54       .type(arser::DataType::STR_VEC)
55       .required(false)
56       .help("Quantize-dequantize weight values required action before quantization. "
57             "Three arguments required: input_dtype(float32) "
58             "output_dtype(uint8) granularity(layer)");
59
60   arser.add_argument(qwmm)
61       .nargs(3)
62       .type(arser::DataType::STR_VEC)
63       .required(false)
64       .help("Quantize with min/max values. "
65             "Three arguments required: input_dtype(float32) "
66             "output_dtype(uint8) granularity(layer)");
67
68   arser.add_argument("input").nargs(1).type(arser::DataType::STR).help("Input circle model");
69   arser.add_argument("output").nargs(1).type(arser::DataType::STR).help("Output circle model");
70
71   try
72   {
73     arser.parse(argc, argv);
74   }
75   catch (const std::runtime_error &err)
76   {
77     std::cout << err.what() << std::endl;
78     std::cout << arser;
79     return 255;
80   }
81
82   if (arser[qdqw])
83   {
84     auto values = arser.get<std::vector<std::string>>(qdqw);
85     if (values.size() != 3)
86     {
87       std::cerr << arser;
88       return 255;
89     }
90     options->enable(Algorithms::QuantizeDequantizeWeights);
91
92     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
93     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
94     options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
95   }
96
97   if (arser[qwmm])
98   {
99     auto values = arser.get<std::vector<std::string>>(qwmm);
100     if (values.size() != 3)
101     {
102       std::cerr << arser;
103       return 255;
104     }
105     options->enable(Algorithms::QuantizeWithMinMax);
106
107     options->param(AlgorithmParameters::Quantize_input_dtype, values.at(0));
108     options->param(AlgorithmParameters::Quantize_output_dtype, values.at(1));
109     options->param(AlgorithmParameters::Quantize_granularity, values.at(2));
110   }
111
112   std::string input_path = arser.get<std::string>("input");
113   std::string output_path = arser.get<std::string>("output");
114
115   // Load model from the file
116   foder::FileLoader file_loader{input_path};
117   std::vector<char> model_data = file_loader.load();
118   const circle::Model *circle_model = circle::GetModel(model_data.data());
119   if (circle_model == nullptr)
120   {
121     std::cerr << "ERROR: Failed to load circle '" << input_path << "'" << std::endl;
122     return EXIT_FAILURE;
123   }
124
125   // Import from input Circle file
126   luci::Importer importer;
127   auto module = importer.importModule(circle_model);
128
129   for (size_t idx = 0; idx < module->size(); ++idx)
130   {
131     auto graph = module->graph(idx);
132
133     // quantize the graph
134     optimizer.quantize(graph);
135
136     if (!luci::validate(graph))
137     {
138       std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
139       return 255;
140     }
141   }
142
143   // Export to output Circle file
144   luci::CircleExporter exporter;
145
146   CircleExpContract contract(module.get(), output_path);
147
148   if (!exporter.invoke(&contract))
149   {
150     std::cerr << "ERROR: Failed to export '" << output_path << "'" << std::endl;
151     return 255;
152   }
153
154   return 0;
155 }