Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / CircleMPQSolver.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <arser/arser.h>
18 #include <vconone/vconone.h>
19 #include <luci/CircleExporter.h>
20 #include <luci/CircleFileExpContract.h>
21
22 #include "bisection/BisectionSolver.h"
23 #include <core/SolverOutput.h>
24
25 #include <iostream>
26 #include <iomanip>
27
28 void print_version(void)
29 {
30   std::cout << "circle-mpqsolver version " << vconone::get_string() << std::endl;
31   std::cout << vconone::get_copyright() << std::endl;
32 }
33
34 int handleAutoAlgorithm(arser::Arser &arser, mpqsolver::bisection::BisectionSolver &solver)
35 {
36   solver.algorithm(mpqsolver::bisection::BisectionSolver::Algorithm::Auto);
37   auto data_path = arser.get<std::string>("--visq_file");
38   if (data_path.empty())
39   {
40     std::cerr << "ERROR: please provide visq_file for auto mode" << std::endl;
41     return false;
42   }
43   solver.setVisqPath(data_path);
44   return true;
45 }
46
47 int entry(int argc, char **argv)
48 {
49   const std::string bisection_str = "--bisection";
50   const std::string save_intermediate_str = "--save_intermediate";
51
52   arser::Arser arser("circle-mpqsolver provides light-weight methods for finding a high-quality "
53                      "mixed-precision model within a reasonable time.");
54
55   arser::Helper::add_version(arser, print_version);
56   arser::Helper::add_verbose(arser);
57
58   arser.add_argument("--data").required(true).help("Path to the test data");
59   arser.add_argument("--data_format").required(false).help("Test data format (default: h5)");
60
61   arser.add_argument("--qerror_ratio")
62     .type(arser::DataType::FLOAT)
63     .default_value(0.5f)
64     .help("quantization error ratio ([0, 1])");
65
66   arser.add_argument(bisection_str)
67     .nargs(1)
68     .type(arser::DataType::STR)
69     .help("Single optional argument for bisection method. "
70           "Whether input node should be quantized to Q16: 'auto', 'true', 'false'.");
71
72   arser.add_argument("--input_model")
73     .required(true)
74     .help("Input float model with min max initialized");
75
76   arser.add_argument("--input_dtype")
77     .type(arser::DataType::STR)
78     .default_value("uint8")
79     .help("Data type of quantized model's inputs (default: uint8)");
80
81   arser.add_argument("--output_dtype")
82     .type(arser::DataType::STR)
83     .default_value("uint8")
84     .help("Data type of quantized model's outputs (default: uint8)");
85
86   arser.add_argument("--output_model").required(true).help("Output quantized model");
87
88   arser.add_argument("--visq_file")
89     .type(arser::DataType::STR)
90     .default_value("")
91     .required(false)
92     .help("*.visq.json file with quantization errors");
93
94   arser.add_argument(save_intermediate_str)
95     .type(arser::DataType::STR)
96     .required(false)
97     .help("path to save intermediate results");
98
99   try
100   {
101     arser.parse(argc, argv);
102   }
103   catch (const std::runtime_error &err)
104   {
105     std::cerr << err.what() << std::endl;
106     std::cout << arser;
107     return EXIT_FAILURE;
108   }
109
110   if (arser.get<bool>("--verbose"))
111   {
112     // The third parameter of setenv means REPLACE.
113     // If REPLACE is zero, it does not overwrite an existing value.
114     setenv("LUCI_LOG", "100", 0);
115   }
116
117   auto data_path = arser.get<std::string>("--data");
118   auto input_model_path = arser.get<std::string>("--input_model");
119   auto output_model_path = arser.get<std::string>("--output_model");
120   auto input_dtype = arser.get<std::string>("--input_dtype");
121   auto output_dtype = arser.get<std::string>("--output_dtype");
122
123   float qerror_ratio = arser.get<float>("--qerror_ratio");
124   if (qerror_ratio < 0.0 || qerror_ratio > 1.f)
125   {
126     std::cerr << "ERROR: quantization ratio must be in [0, 1]" << std::endl;
127     return EXIT_FAILURE;
128   }
129
130   SolverOutput::get() << ">> Searching mixed precision configuration \n"
131                       << "model:" << input_model_path << "\n"
132                       << "dataset: " << data_path << "\n"
133                       << "input dtype: " << input_dtype << "\n"
134                       << "output dtype: " << output_dtype << "\n";
135
136   if (arser[bisection_str])
137   {
138     // optimize
139     using namespace mpqsolver::bisection;
140
141     BisectionSolver solver(data_path, qerror_ratio, input_dtype, output_dtype);
142     {
143       auto value = arser.get<std::string>(bisection_str);
144       if (value == "auto")
145       {
146         SolverOutput::get() << "algorithm: bisection (auto)\n";
147         if (!handleAutoAlgorithm(arser, solver))
148         {
149           return EXIT_FAILURE;
150         }
151       }
152       else if (value == "true")
153       {
154         SolverOutput::get() << "algorithm: bisection (Q16AtFront)";
155         solver.algorithm(BisectionSolver::Algorithm::ForceQ16Front);
156       }
157       else if (value == "false")
158       {
159         SolverOutput::get() << "algorithm: bisection (Q8AtFront)";
160         solver.algorithm(BisectionSolver::Algorithm::ForceQ16Back);
161       }
162       else
163       {
164         std::cerr << "ERROR: Unrecognized option for bisection algortithm" << input_model_path
165                   << std::endl;
166         return EXIT_FAILURE;
167       }
168     }
169
170     if (arser[save_intermediate_str])
171     {
172       auto data_path = arser.get<std::string>(save_intermediate_str);
173       if (!data_path.empty())
174       {
175         solver.set_save_intermediate(data_path);
176       }
177     }
178
179     SolverOutput::get() << "qerror metric: MAE\n"
180                         << "target qerror ratio: " << qerror_ratio << "\n";
181
182     auto optimized = solver.run(input_model_path);
183     if (optimized == nullptr)
184     {
185       std::cerr << "ERROR: Failed to build mixed precision model" << input_model_path << std::endl;
186       return EXIT_FAILURE;
187     }
188
189     // save optimized
190     {
191       SolverOutput::get() << "Saving output model to " << output_model_path << "\n";
192       luci::CircleExporter exporter;
193       luci::CircleFileExpContract contract(optimized.get(), output_model_path);
194       if (!exporter.invoke(&contract))
195       {
196         std::cerr << "ERROR: Failed to export mixed precision model" << input_model_path
197                   << std::endl;
198         return EXIT_FAILURE;
199       }
200     }
201   }
202   else
203   {
204     std::cerr << "ERROR: Unrecognized solver" << std::endl;
205     return EXIT_FAILURE;
206   }
207
208   return EXIT_SUCCESS;
209 }