.type(arser::DataType::FLOAT)
.help("Record n'th percentile of max");
+ arser.add_argument("--moving_avg_batch")
+ .type(arser::DataType::INT32)
+ .help("Batch size of moving average algorithm (default: 16)");
+
+ arser.add_argument("--moving_avg_const")
+ .type(arser::DataType::FLOAT)
+ .help("Hyperparameter (C) to compute moving average (default: 0.1). Update equation: avg <- "
+ "avg + C * (curr_batch_avg - avg)");
+
arser.add_argument("--mode").help("Record mode. percentile (default) or moving_average");
arser.add_argument("--input_data_format")
std::string mode("percentile");
float min_percentile = 1.0;
float max_percentile = 99.0;
+ uint32_t moving_avg_batch = 16;
+ float moving_avg_const = 0.1;
std::string input_data_format("h5");
uint32_t num_threads = 1;
if (arser["--mode"])
mode = arser.get<std::string>("--mode");
+ if (arser["--moving_avg_batch"])
+ moving_avg_batch = arser.get<int>("--moving_avg_batch");
+
+ if (arser["--moving_avg_const"])
+ moving_avg_const = arser.get<float>("--moving_avg_const");
+
if (mode != "percentile" && mode != "moving_average")
throw std::runtime_error("Unsupported mode");
if (arser["--input_data_format"])
input_data_format = arser.get<std::string>("--input_data_format");
- RecordMinMax rmm(num_threads);
+ std::unique_ptr<MinMaxComputer> computer;
+ {
+ if (mode == "percentile")
+ {
+ computer = make_percentile_computer(min_percentile, max_percentile);
+ }
+ else if (mode == "moving_average")
+ {
+ computer = make_moving_avg_computer(moving_avg_batch, moving_avg_const);
+ }
+ else
+ {
+ assert(false);
+ }
+ }
+
+ RecordMinMax rmm(num_threads, std::move(computer));
// TODO: support parallel record for profile with random data
if (num_threads > 1 and not arser["--input_data"])
{
// Profile min/max while executing the H5 data
if (num_threads == 1)
- rmm.profileData(mode, input_data_path, min_percentile, max_percentile);
+ rmm.profileData(input_data_path);
else
{
INFO(l) << "Using parallel recording" << std::endl;
- rmm.profileDataInParallel(mode, input_data_path, min_percentile, max_percentile);
+ rmm.profileDataInParallel(input_data_path);
}
}
// input_data is a text file having a file path in each line.
else if (input_data_format == "list" || input_data_format == "filelist")
{
// Profile min/max while executing the list of Raw data
- rmm.profileRawData(mode, input_data_path, min_percentile, max_percentile);
+ rmm.profileRawData(input_data_path);
}
else if (input_data_format == "directory" || input_data_format == "dir")
{
// Profile min/max while executing all files under the given directory
// The contents of each file are the same as the raw data in the 'list' type
- rmm.profileRawDataDirectory(mode, input_data_path, min_percentile, max_percentile);
+ rmm.profileRawDataDirectory(input_data_path);
}
else
{
else
{
// Profile min/max while executing random input data
- rmm.profileDataWithRandomInputs(mode, min_percentile, max_percentile);
+ rmm.profileDataWithRandomInputs();
}
// Save profiled values to the model