Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / record-minmax / driver / Driver.cpp
index d8ed95e..24a4ff8 100644 (file)
@@ -63,6 +63,15 @@ int entry(const int argc, char **argv)
     .type(arser::DataType::FLOAT)
     .help("Record n'th percentile of max");
 
+  arser.add_argument("--moving_avg_batch")
+    .type(arser::DataType::INT32)
+    .help("Batch size of moving average algorithm (default: 16)");
+
+  arser.add_argument("--moving_avg_const")
+    .type(arser::DataType::FLOAT)
+    .help("Hyperparameter (C) to compute moving average (default: 0.1). Update equation: avg <- "
+          "avg + C * (curr_batch_avg - avg)");
+
   arser.add_argument("--mode").help("Record mode. percentile (default) or moving_average");
 
   arser.add_argument("--input_data_format")
@@ -100,6 +109,8 @@ int entry(const int argc, char **argv)
   std::string mode("percentile");
   float min_percentile = 1.0;
   float max_percentile = 99.0;
+  uint32_t moving_avg_batch = 16;
+  float moving_avg_const = 0.1;
   std::string input_data_format("h5");
   uint32_t num_threads = 1;
 
@@ -118,6 +129,12 @@ int entry(const int argc, char **argv)
   if (arser["--mode"])
     mode = arser.get<std::string>("--mode");
 
+  if (arser["--moving_avg_batch"])
+    moving_avg_batch = arser.get<int>("--moving_avg_batch");
+
+  if (arser["--moving_avg_const"])
+    moving_avg_const = arser.get<float>("--moving_avg_const");
+
   if (mode != "percentile" && mode != "moving_average")
     throw std::runtime_error("Unsupported mode");
 
@@ -127,7 +144,23 @@ int entry(const int argc, char **argv)
   if (arser["--input_data_format"])
     input_data_format = arser.get<std::string>("--input_data_format");
 
-  RecordMinMax rmm(num_threads);
+  std::unique_ptr<MinMaxComputer> computer;
+  {
+    if (mode == "percentile")
+    {
+      computer = make_percentile_computer(min_percentile, max_percentile);
+    }
+    else if (mode == "moving_average")
+    {
+      computer = make_moving_avg_computer(moving_avg_batch, moving_avg_const);
+    }
+    else
+    {
+      assert(false);
+    }
+  }
+
+  RecordMinMax rmm(num_threads, std::move(computer));
 
   // TODO: support parallel record for profile with random data
   if (num_threads > 1 and not arser["--input_data"])
@@ -152,11 +185,11 @@ int entry(const int argc, char **argv)
     {
       // Profile min/max while executing the H5 data
       if (num_threads == 1)
-        rmm.profileData(mode, input_data_path, min_percentile, max_percentile);
+        rmm.profileData(input_data_path);
       else
       {
         INFO(l) << "Using parallel recording" << std::endl;
-        rmm.profileDataInParallel(mode, input_data_path, min_percentile, max_percentile);
+        rmm.profileDataInParallel(input_data_path);
       }
     }
     // input_data is a text file having a file path in each line.
@@ -170,13 +203,13 @@ int entry(const int argc, char **argv)
     else if (input_data_format == "list" || input_data_format == "filelist")
     {
       // Profile min/max while executing the list of Raw data
-      rmm.profileRawData(mode, input_data_path, min_percentile, max_percentile);
+      rmm.profileRawData(input_data_path);
     }
     else if (input_data_format == "directory" || input_data_format == "dir")
     {
       // Profile min/max while executing all files under the given directory
       // The contents of each file is same as the raw data in the 'list' type
-      rmm.profileRawDataDirectory(mode, input_data_path, min_percentile, max_percentile);
+      rmm.profileRawDataDirectory(input_data_path);
     }
     else
     {
@@ -187,7 +220,7 @@ int entry(const int argc, char **argv)
   else
   {
     // Profile min/max while executing random input data
-    rmm.profileDataWithRandomInputs(mode, min_percentile, max_percentile);
+    rmm.profileDataWithRandomInputs();
   }
 
   // Save profiled values to the model