Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / driver / Driver.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "RecordMinMax.h"
18
19 #include <arser/arser.h>
20 #include <vconone/vconone.h>
21
22 #include <luci/UserSettings.h>
23
24 // TODO declare own log signature of record-minmax
25 #include <luci/Log.h>
26
27 void print_version(void)
28 {
29   std::cout << "record-minmax version " << vconone::get_string() << std::endl;
30   std::cout << vconone::get_copyright() << std::endl;
31 }
32
33 int entry(const int argc, char **argv)
34 {
35   using namespace record_minmax;
36
37   LOGGER(l);
38
39   arser::Arser arser(
40     "Embedding min/max values of activations to the circle model for post-training quantization");
41
42   arser::Helper::add_version(arser, print_version);
43   arser::Helper::add_verbose(arser);
44
45   arser.add_argument("--input_model").required(true).help("Input model filepath");
46
47   arser.add_argument("--input_data")
48     .help("Input data filepath. If not given, record-minmax will run with randomly generated data. "
49           "Note that the random dataset does not represent inference workload, leading to poor "
50           "model accuracy.");
51
52   arser.add_argument("--output_model").required(true).help("Output model filepath");
53
54   arser.add_argument("--min_percentile")
55     .type(arser::DataType::FLOAT)
56     .help("Record n'th percentile of min");
57
58   arser.add_argument("--num_threads")
59     .type(arser::DataType::INT32)
60     .help("Number of threads (default: 1)");
61
62   arser.add_argument("--max_percentile")
63     .type(arser::DataType::FLOAT)
64     .help("Record n'th percentile of max");
65
66   arser.add_argument("--moving_avg_batch")
67     .type(arser::DataType::INT32)
68     .help("Batch size of moving average algorithm (default: 16)");
69
70   arser.add_argument("--moving_avg_const")
71     .type(arser::DataType::FLOAT)
72     .help("Hyperparameter (C) to compute moving average (default: 0.1). Update equation: avg <- "
73           "avg + C * (curr_batch_avg - avg)");
74
75   arser.add_argument("--mode").help("Record mode. percentile (default) or moving_average");
76
77   arser.add_argument("--input_data_format")
78     .help("Input data format. h5/hdf5 (default) or list/filelist");
79
80   arser.add_argument("--generate_profile_data")
81     .nargs(0)
82     .default_value(false)
83     .help("This will turn on profiling data generation.");
84
85   try
86   {
87     arser.parse(argc, argv);
88   }
89   catch (const std::runtime_error &err)
90   {
91     std::cout << err.what() << std::endl;
92     std::cout << arser;
93     return 255;
94   }
95
96   if (arser.get<bool>("--verbose"))
97   {
98     // The third parameter of setenv means REPLACE.
99     // If REPLACE is zero, it does not overwrite an existing value.
100     setenv("LUCI_LOG", "100", 0);
101   }
102
103   auto settings = luci::UserSettings::settings();
104
105   auto input_model_path = arser.get<std::string>("--input_model");
106   auto output_model_path = arser.get<std::string>("--output_model");
107
108   // Default values
109   std::string mode("percentile");
110   float min_percentile = 1.0;
111   float max_percentile = 99.0;
112   uint32_t moving_avg_batch = 16;
113   float moving_avg_const = 0.1;
114   std::string input_data_format("h5");
115   uint32_t num_threads = 1;
116
117   if (arser["--min_percentile"])
118     min_percentile = arser.get<float>("--min_percentile");
119
120   if (arser["--num_threads"])
121     num_threads = arser.get<int>("--num_threads");
122
123   if (num_threads < 1)
124     throw std::runtime_error("The number of threads must be greater than zero");
125
126   if (arser["--max_percentile"])
127     max_percentile = arser.get<float>("--max_percentile");
128
129   if (arser["--mode"])
130     mode = arser.get<std::string>("--mode");
131
132   if (arser["--moving_avg_batch"])
133     moving_avg_batch = arser.get<int>("--moving_avg_batch");
134
135   if (arser["--moving_avg_const"])
136     moving_avg_const = arser.get<float>("--moving_avg_const");
137
138   if (mode != "percentile" && mode != "moving_average")
139     throw std::runtime_error("Unsupported mode");
140
141   if (arser["--generate_profile_data"])
142     settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
143
144   if (arser["--input_data_format"])
145     input_data_format = arser.get<std::string>("--input_data_format");
146
147   std::unique_ptr<MinMaxComputer> computer;
148   {
149     if (mode == "percentile")
150     {
151       computer = make_percentile_computer(min_percentile, max_percentile);
152     }
153     else if (mode == "moving_average")
154     {
155       computer = make_moving_avg_computer(moving_avg_batch, moving_avg_const);
156     }
157     else
158     {
159       assert(false);
160     }
161   }
162
163   RecordMinMax rmm(num_threads, std::move(computer));
164
165   // TODO: support parallel record for profile with random data
166   if (num_threads > 1 and not arser["--input_data"])
167   {
168     throw std::runtime_error("Input data must be given for parallel recording");
169   }
170
171   // Initialize interpreter and observer
172   rmm.initialize(input_model_path);
173
174   if (arser["--input_data"])
175   {
176     auto input_data_path = arser.get<std::string>("--input_data");
177
178     // TODO: support parallel record from file and dir input data format
179     if (num_threads > 1 and not(input_data_format == "h5") and not(input_data_format == "hdf5"))
180     {
181       throw std::runtime_error("Parallel recording is used only for h5 now");
182     }
183
184     if (input_data_format == "h5" || input_data_format == "hdf5")
185     {
186       // Profile min/max while executing the H5 data
187       if (num_threads == 1)
188         rmm.profileData(input_data_path);
189       else
190       {
191         INFO(l) << "Using parallel recording" << std::endl;
192         rmm.profileDataInParallel(input_data_path);
193       }
194     }
195     // input_data is a text file having a file path in each line.
196     // Each data file is composed of inputs of a model, concatenated in
197     // the same order with the input index of the model
198     //
199     // For example, for a model with n inputs, the contents of each data
200     // file can be visualized as below
201     // [input 1][input 2]...[input n]
202     // |start............end of file|
203     else if (input_data_format == "list" || input_data_format == "filelist")
204     {
205       // Profile min/max while executing the list of Raw data
206       rmm.profileRawData(input_data_path);
207     }
208     else if (input_data_format == "directory" || input_data_format == "dir")
209     {
210       // Profile min/max while executing all files under the given directory
211       // The contents of each file is same as the raw data in the 'list' type
212       rmm.profileRawDataDirectory(input_data_path);
213     }
214     else
215     {
216       throw std::runtime_error(
217         "Unsupported input data format (supported formats: h5/hdf5 (default), list/filelist)");
218     }
219   }
220   else
221   {
222     // Profile min/max while executing random input data
223     rmm.profileDataWithRandomInputs();
224   }
225
226   // Save profiled values to the model
227   rmm.saveModel(output_model_path);
228
229   return EXIT_SUCCESS;
230 }