Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / tests / tools / tflite_run / src / args.cc
index fac2a6e..f8f581b 100644 (file)
@@ -37,6 +37,39 @@ Args::Args(const int argc, char **argv) noexcept
 
 void Args::Initialize(void)
 {
+  auto process_input = [&](const std::string &v) {
+    _input_filename = v;
+
+    if (!_input_filename.empty())
+    {
+      if (access(_input_filename.c_str(), F_OK) == -1)
+      {
+        std::cerr << "input image file not found: " << _input_filename << "\n";
+      }
+    }
+  };
+
+  auto process_tflite = [&](const std::string &v) {
+    _tflite_filename = v;
+
+    if (_tflite_filename.empty())
+    {
+      // TODO Print usage instead of the below message
+      std::cerr << "Please specify tflite file. Run with `--help` for usage."
+                << "\n";
+
+      exit(1);
+    }
+    else
+    {
+      if (access(_tflite_filename.c_str(), F_OK) == -1)
+      {
+        std::cerr << "tflite file not found: " << _tflite_filename << "\n";
+        exit(1);
+      }
+    }
+  };
+
   try
   {
     // General options
@@ -45,19 +78,19 @@ void Args::Initialize(void)
     // clang-format off
   general.add_options()
     ("help,h", "Display available options")
-    ("input,i", po::value<std::string>()->default_value(""), "Input filename")
-    ("dump,d", po::value<std::string>()->default_value(""), "Output filename")
-    ("ishapes", po::value<std::vector<int>>()->multitoken(), "Input shapes")
-    ("compare,c", po::value<std::string>()->default_value(""), "filename to be compared with")
-    ("tflite", po::value<std::string>()->required())
-    ("num_runs,r", po::value<int>()->default_value(1), "The number of runs")
-    ("warmup_runs,w", po::value<int>()->default_value(0), "The number of warmup runs")
-    ("run_delay,t", po::value<int>()->default_value(-1), "Delay time(ms) between runs (as default no delay")
-    ("gpumem_poll,g", po::value<bool>()->default_value(false), "Check gpu memory polling separately")
+    ("input,i", po::value<std::string>()->default_value("")->notifier(process_input), "Input filename")
+    ("dump,d", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _dump_filename = v; }), "Output filename")
+    ("ishapes", po::value<std::vector<int>>()->multitoken()->notifier([&](const auto &v) { _input_shapes = v; }), "Input shapes")
+    ("compare,c", po::value<std::string>()->default_value("")->notifier([&](const auto &v) { _compare_filename = v; }), "filename to be compared with")
+    ("tflite", po::value<std::string>()->required()->notifier(process_tflite))
+    ("num_runs,r", po::value<int>()->default_value(1)->notifier([&](const auto &v) { _num_runs = v; }), "The number of runs")
+    ("warmup_runs,w", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _warmup_runs = v; }), "The number of warmup runs")
+    ("run_delay,t", po::value<int>()->default_value(-1)->notifier([&](const auto &v) { _run_delay = v; }), "Delay time(ms) between runs (as default no delay)")
+    ("gpumem_poll,g", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _gpumem_poll = v; }), "Check gpu memory polling separately")
     ("mem_poll,m", po::value<bool>()->default_value(false), "Check memory polling")
-    ("write_report,p", po::value<bool>()->default_value(false), "Write report")
-    ("validate", po::value<bool>()->default_value(true), "Validate tflite model")
-    ("verbose_level,v", po::value<int>()->default_value(0), "Verbose level\n"
+    ("write_report,p", po::value<bool>()->default_value(false)->notifier([&](const auto &v) { _write_report = v; }), "Write report")
+    ("validate", po::value<bool>()->default_value(true)->notifier([&](const auto &v) { _tflite_validate = v; }), "Validate tflite model")
+    ("verbose_level,v", po::value<int>()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }), "Verbose level\n"
          "0: prints the only result. Messages btw run don't print\n"
          "1: prints result and message btw run\n"
          "2: prints all of messages to print\n")
@@ -105,80 +138,7 @@ void Args::Parse(const int argc, char **argv)
 
   po::notify(vm);
 
-  if (vm.count("dump"))
-  {
-    _dump_filename = vm["dump"].as<std::string>();
-  }
-
-  if (vm.count("compare"))
-  {
-    _compare_filename = vm["compare"].as<std::string>();
-  }
-
-  if (vm.count("input"))
-  {
-    _input_filename = vm["input"].as<std::string>();
-
-    if (!_input_filename.empty())
-    {
-      if (access(_input_filename.c_str(), F_OK) == -1)
-      {
-        std::cerr << "input image file not found: " << _input_filename << "\n";
-      }
-    }
-  }
-
-  if (vm.count("ishapes"))
-  {
-    _input_shapes.resize(vm["ishapes"].as<std::vector<int>>().size());
-    for (auto i = 0; i < _input_shapes.size(); i++)
-    {
-      _input_shapes[i] = vm["ishapes"].as<std::vector<int>>()[i];
-    }
-  }
-
-  if (vm.count("tflite"))
-  {
-    _tflite_filename = vm["tflite"].as<std::string>();
-
-    if (_tflite_filename.empty())
-    {
-      // TODO Print usage instead of the below message
-      std::cerr << "Please specify tflite file. Run with `--help` for usage."
-                << "\n";
-
-      exit(1);
-    }
-    else
-    {
-      if (access(_tflite_filename.c_str(), F_OK) == -1)
-      {
-        std::cerr << "tflite file not found: " << _tflite_filename << "\n";
-        exit(1);
-      }
-    }
-  }
-
-  if (vm.count("num_runs"))
-  {
-    _num_runs = vm["num_runs"].as<int>();
-  }
-
-  if (vm.count("warmup_runs"))
-  {
-    _warmup_runs = vm["warmup_runs"].as<int>();
-  }
-
-  if (vm.count("run_delay"))
-  {
-    _run_delay = vm["run_delay"].as<int>();
-  }
-
-  if (vm.count("gpumem_poll"))
-  {
-    _gpumem_poll = vm["gpumem_poll"].as<bool>();
-  }
-
+  // This must be run after `notify` as `_warm_up_runs` must have been processed before.
   if (vm.count("mem_poll"))
   {
     _mem_poll = vm["mem_poll"].as<bool>();
@@ -188,21 +148,6 @@ void Args::Parse(const int argc, char **argv)
       _warmup_runs = 1;
     }
   }
-
-  if (vm.count("write_report"))
-  {
-    _write_report = vm["write_report"].as<bool>();
-  }
-
-  if (vm.count("validate"))
-  {
-    _tflite_validate = vm["validate"].as<bool>();
-  }
-
-  if (vm.count("verbose_level"))
-  {
-    _verbose_level = vm["verbose_level"].as<int>();
-  }
 }
 
 } // end of namespace TFLiteRun