23e8fd4ced612e1820b3bf07790eca5ea3e83f49
[platform/core/ml/nnfw.git] / compiler / circle-mpqsolver / src / CircleMPQSolver.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <arser/arser.h>
18 #include <vconone/vconone.h>
19 #include <luci/CircleExporter.h>
20 #include <luci/CircleFileExpContract.h>
21 #include <luci/Log.h>
22
23 #include "bisection/BisectionSolver.h"
24
25 #include <iostream>
26 #include <iomanip>
27 #include <chrono>
28
29 void print_version(void)
30 {
31   std::cout << "circle-mpqsolver version " << vconone::get_string() << std::endl;
32   std::cout << vconone::get_copyright() << std::endl;
33 }
34
35 int entry(int argc, char **argv)
36 {
37   LOGGER(l);
38
39   const std::string bisection_str = "--bisection";
40
41   arser::Arser arser("circle-mpqsolver provides light-weight methods for finding a high-quality "
42                      "mixed-precision model within a reasonable time.");
43
44   arser::Helper::add_version(arser, print_version);
45   arser::Helper::add_verbose(arser);
46
47   arser.add_argument("--data").required(true).help("Path to the test data");
48   arser.add_argument("--data_format").required(false).help("Test data format (default: h5)");
49
50   arser.add_argument("--qerror_ratio")
51     .type(arser::DataType::FLOAT)
52     .default_value(0.5f)
53     .help("quantization error ratio ([0, 1])");
54
55   arser.add_argument(bisection_str)
56     .nargs(1)
57     .type(arser::DataType::STR)
58     .help("Single optional argument for bisection method. "
59           "Whether input node should be quantized to Q16: 'auto', 'true', 'false'.");
60
61   arser.add_argument("--input_model")
62     .required(true)
63     .help("Input float model with min max initialized");
64
65   arser.add_argument("--input_dtype")
66     .type(arser::DataType::STR)
67     .default_value("uint8")
68     .help("Data type of quantized model's inputs (default: uint8)");
69
70   arser.add_argument("--output_dtype")
71     .type(arser::DataType::STR)
72     .default_value("uint8")
73     .help("Data type of quantized model's outputs (default: uint8)");
74
75   arser.add_argument("--output_model").required(true).help("Output quantized model");
76
77   try
78   {
79     arser.parse(argc, argv);
80   }
81   catch (const std::runtime_error &err)
82   {
83     std::cerr << err.what() << std::endl;
84     std::cout << arser;
85     return EXIT_FAILURE;
86   }
87
88   if (arser.get<bool>("--verbose"))
89   {
90     // The third parameter of setenv means REPLACE.
91     // If REPLACE is zero, it does not overwrite an existing value.
92     setenv("LUCI_LOG", "100", 0);
93   }
94
95   auto data_path = arser.get<std::string>("--data");
96   auto input_model_path = arser.get<std::string>("--input_model");
97   auto output_model_path = arser.get<std::string>("--output_model");
98   auto input_dtype = arser.get<std::string>("--input_dtype");
99   auto output_dtype = arser.get<std::string>("--output_dtype");
100
101   float qerror_ratio = arser.get<float>("--qerror_ratio");
102   if (qerror_ratio < 0.0 || qerror_ratio > 1.f)
103   {
104     std::cerr << "ERROR: quantization ratio must be in [0, 1]" << std::endl;
105     return EXIT_FAILURE;
106   }
107   auto start = std::chrono::high_resolution_clock::now();
108
109   if (arser[bisection_str])
110   {
111     // optimize
112     using namespace mpqsolver::bisection;
113
114     BisectionSolver solver(data_path, qerror_ratio, input_dtype, output_dtype);
115     {
116       auto value = arser.get<std::string>(bisection_str);
117       if (value == "auto")
118       {
119         solver.algorithm(BisectionSolver::Algorithm::Auto);
120       }
121       else if (value == "true")
122       {
123         solver.algorithm(BisectionSolver::Algorithm::ForceQ16Front);
124       }
125       else if (value == "false")
126       {
127         solver.algorithm(BisectionSolver::Algorithm::ForceQ16Back);
128       }
129       else
130       {
131         std::cerr << "ERROR: Unrecognized option for bisection algortithm" << input_model_path
132                   << std::endl;
133         return EXIT_FAILURE;
134       }
135     }
136
137     auto optimized = solver.run(input_model_path);
138     if (optimized == nullptr)
139     {
140       std::cerr << "ERROR: Failed to build mixed precision model" << input_model_path << std::endl;
141       return EXIT_FAILURE;
142     }
143
144     // save optimized
145     {
146       luci::CircleExporter exporter;
147       luci::CircleFileExpContract contract(optimized.get(), output_model_path);
148       if (!exporter.invoke(&contract))
149       {
150         std::cerr << "ERROR: Failed to export mixed precision model" << input_model_path
151                   << std::endl;
152         return EXIT_FAILURE;
153       }
154     }
155   }
156   else
157   {
158     std::cerr << "ERROR: Unrecognized solver" << std::endl;
159     return EXIT_FAILURE;
160   }
161
162   auto duration = std::chrono::duration_cast<std::chrono::seconds>(
163     std::chrono::high_resolution_clock::now() - start);
164   VERBOSE(l, 0) << "Elapsed Time: " << std::setprecision(5) << duration.count() / 60.f
165                 << " minutes." << std::endl;
166
167   return EXIT_SUCCESS;
168 }