2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "RecordMinMax.h"
19 #include <arser/arser.h>
20 #include <vconone/vconone.h>
22 #include <luci/UserSettings.h>
24 // TODO declare own log signature of record-minmax
27 void print_version(void)
29 std::cout << "record-minmax version " << vconone::get_string() << std::endl;
30 std::cout << vconone::get_copyright() << std::endl;
// Command-line driver of record-minmax: declares the CLI options, parses argv, runs the input
// model to record min/max values of activations and writes the result to the output model.
33 int entry(const int argc, char **argv)
35 using namespace record_minmax;
// NOTE(review): the statement this string literal belongs to (the arser parser construction)
// is not visible in this view — confirm it is present in the file.
40 "Embedding min/max values of activations to the circle model for post-training quantization");
// Register the shared version/verbose options through arser::Helper; print_version is the
// callback used for the version option.
42 arser::Helper::add_version(arser, print_version);
43 arser::Helper::add_verbose(arser);
// Required path of the model to profile.
45 arser.add_argument("--input_model").required(true).help("Input model filepath");
// Optional dataset; when it is absent, recording runs on randomly generated inputs instead.
47 arser.add_argument("--input_data")
48 .help("Input data filepath. If not given, record-minmax will run with randomly generated data. "
49 "Note that the random dataset does not represent inference workload, leading to poor "
// NOTE(review): the remainder of the --input_data help text and the closing `);` are not
// visible in this view — confirm the statement is complete.
// Required path where the model with embedded min/max values is written.
52 arser.add_argument("--output_model").required(true).help("Output model filepath");
// Tuning options for the two recording modes (percentile / moving_average) and for parallel
// recording. Defaults are not declared here; they are assigned where the options are read.
54 arser.add_argument("--min_percentile")
55 .type(arser::DataType::FLOAT)
56 .help("Record n'th percentile of min");
58 arser.add_argument("--num_threads")
59 .type(arser::DataType::INT32)
60 .help("Number of threads (default: 1)");
62 arser.add_argument("--max_percentile")
63 .type(arser::DataType::FLOAT)
64 .help("Record n'th percentile of max");
66 arser.add_argument("--moving_avg_batch")
67 .type(arser::DataType::INT32)
68 .help("Batch size of moving average algorithm (default: 16)");
70 arser.add_argument("--moving_avg_const")
71 .type(arser::DataType::FLOAT)
72 .help("Hyperparameter (C) to compute moving average (default: 0.1). Update equation: avg <- "
73 "avg + C * (curr_batch_avg - avg)");
// Selects which min/max computer is built; validated after parsing.
75 arser.add_argument("--mode").help("Record mode. percentile (default) or moving_average");
77 arser.add_argument("--input_data_format")
78 .help("Input data format. h5/hdf5 (default) or list/filelist");
// On/off switch: it is only tested for presence (arser["--generate_profile_data"]) further down.
// NOTE(review): the lines between this declaration and `.help(...)` are not visible here —
// confirm the flag is declared so that it takes no value.
80 arser.add_argument("--generate_profile_data")
83 .help("This will turn on profiling data generation.");
// Parse argv; a std::runtime_error raised by the parser is caught and its message printed.
// NOTE(review): the enclosing `try` and whatever follows the print in the handler are not
// visible in this view — confirm the handler exits instead of continuing with unparsed options.
87 arser.parse(argc, argv);
89 catch (const std::runtime_error &err)
91 std::cout << err.what() << std::endl;
// --verbose enables luci logging (LUCI_LOG=100) unless the user already set LUCI_LOG.
96 if (arser.get<bool>("--verbose"))
98 // The third parameter of setenv means REPLACE.
99 // If REPLACE is zero, it does not overwrite an existing value.
100 setenv("LUCI_LOG", "100", 0);
// luci user settings handle; used below to enable ProfilingDataGen when requested.
103 auto settings = luci::UserSettings::settings();
105 auto input_model_path = arser.get<std::string>("--input_model");
106 auto output_model_path = arser.get<std::string>("--output_model");
109 std::string mode("percentile");
110 float min_percentile = 1.0;
111 float max_percentile = 99.0;
112 uint32_t moving_avg_batch = 16;
113 float moving_avg_const = 0.1;
114 std::string input_data_format("h5");
115 uint32_t num_threads = 1;
117 if (arser["--min_percentile"])
118 min_percentile = arser.get<float>("--min_percentile");
120 if (arser["--num_threads"])
121 num_threads = arser.get<int>("--num_threads");
124 throw std::runtime_error("The number of threads must be greater than zero");
126 if (arser["--max_percentile"])
127 max_percentile = arser.get<float>("--max_percentile");
130 mode = arser.get<std::string>("--mode");
132 if (arser["--moving_avg_batch"])
133 moving_avg_batch = arser.get<int>("--moving_avg_batch");
135 if (arser["--moving_avg_const"])
136 moving_avg_const = arser.get<float>("--moving_avg_const");
138 if (mode != "percentile" && mode != "moving_average")
139 throw std::runtime_error("Unsupported mode");
141 if (arser["--generate_profile_data"])
142 settings->set(luci::UserSettings::Key::ProfilingDataGen, true);
144 if (arser["--input_data_format"])
145 input_data_format = arser.get<std::string>("--input_data_format");
147 std::unique_ptr<MinMaxComputer> computer;
149 if (mode == "percentile")
151 computer = make_percentile_computer(min_percentile, max_percentile);
153 else if (mode == "moving_average")
155 computer = make_moving_avg_computer(moving_avg_batch, moving_avg_const);
163 RecordMinMax rmm(num_threads, std::move(computer));
165 // TODO: support parallel record for profile with random data
166 if (num_threads > 1 and not arser["--input_data"])
168 throw std::runtime_error("Input data must be given for parallel recording");
171 // Initialize interpreter and observer
172 rmm.initialize(input_model_path);
174 if (arser["--input_data"])
176 auto input_data_path = arser.get<std::string>("--input_data");
178 // TODO: support parallel record from file and dir input data format
179 if (num_threads > 1 and not(input_data_format == "h5") and not(input_data_format == "hdf5"))
181 throw std::runtime_error("Parallel recording is used only for h5 now");
184 if (input_data_format == "h5" || input_data_format == "hdf5")
186 // Profile min/max while executing the H5 data
187 if (num_threads == 1)
188 rmm.profileData(input_data_path);
191 INFO(l) << "Using parallel recording" << std::endl;
192 rmm.profileDataInParallel(input_data_path);
195 // input_data is a text file having a file path in each line.
196 // Each data file is composed of inputs of a model, concatenated in
197 // the same order with the input index of the model
199 // For example, for a model with n inputs, the contents of each data
200 // file can be visualized as below
201 // [input 1][input 2]...[input n]
202 // |start............end of file|
203 else if (input_data_format == "list" || input_data_format == "filelist")
205 // Profile min/max while executing the list of Raw data
206 rmm.profileRawData(input_data_path);
208 else if (input_data_format == "directory" || input_data_format == "dir")
210 // Profile min/max while executing all files under the given directory
211 // The contents of each file is same as the raw data in the 'list' type
212 rmm.profileRawDataDirectory(input_data_path);
216 throw std::runtime_error(
217 "Unsupported input data format (supported formats: h5/hdf5 (default), list/filelist)");
222 // Profile min/max while executing random input data
223 rmm.profileDataWithRandomInputs();
226 // Save profiled values to the model
227 rmm.saveModel(output_model_path);