From b5229f4d4703724d75bfe726409c2f418cdce695 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EA=B9=80=EC=9A=A9=EC=84=AD/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 4 Dec 2019 17:57:02 +0900 Subject: [PATCH] [tflite_run] Turn options memory polling and writing report off as default (#9377) Turn options memory polling and writing report off as default Signed-off-by: Yongseop Kim --- tests/tools/tflite_run/src/args.cc | 12 +++++ tests/tools/tflite_run/src/args.h | 4 ++ tests/tools/tflite_run/src/tflite_run.cc | 86 ++++++++++++++++++++------------ 3 files changed, 70 insertions(+), 32 deletions(-) diff --git a/tests/tools/tflite_run/src/args.cc b/tests/tools/tflite_run/src/args.cc index 8c1275f..4ec9466 100644 --- a/tests/tools/tflite_run/src/args.cc +++ b/tests/tools/tflite_run/src/args.cc @@ -46,6 +46,8 @@ void Args::Initialize(void) ("num_runs,r", po::value()->default_value(1), "The number of runs") ("warmup_runs,w", po::value()->default_value(0), "The number of warmup runs") ("gpumem_poll,g", po::value()->default_value(false), "Check gpu memory polling separately") + ("mem_poll,m", po::value()->default_value(false), "Check memory polling") + ("write_report,p", po::value()->default_value(false), "Write report") ; // clang-format on @@ -150,6 +152,16 @@ void Args::Parse(const int argc, char **argv) { _gpumem_poll = vm["gpumem_poll"].as(); } + + if (vm.count("mem_poll")) + { + _mem_poll = vm["mem_poll"].as(); + } + + if (vm.count("write_report")) + { + _write_report = vm["write_report"].as(); + } } } // end of namespace TFLiteRun diff --git a/tests/tools/tflite_run/src/args.h b/tests/tools/tflite_run/src/args.h index 00d1b74..8ba13b8 100644 --- a/tests/tools/tflite_run/src/args.h +++ b/tests/tools/tflite_run/src/args.h @@ -39,6 +39,8 @@ public: const int getNumRuns(void) const { return _num_runs; } const int getWarmupRuns(void) const { return _warmup_runs; } const bool getGpuMemoryPoll(void) const { return _gpumem_poll; } + const bool getMemoryPoll(void) const { return _mem_poll; } + const bool getWriteReport(void) const { return _write_report; } private: void Initialize(); @@ -56,6 +58,8 @@ private: int _num_runs; int _warmup_runs; bool _gpumem_poll; + bool _mem_poll; + bool _write_report; }; } // end of namespace TFLiteRun diff --git a/tests/tools/tflite_run/src/tflite_run.cc b/tests/tools/tflite_run/src/tflite_run.cc index 174896e..e77f255 100644 --- a/tests/tools/tflite_run/src/tflite_run.cc +++ b/tests/tools/tflite_run/src/tflite_run.cc @@ -121,8 +121,20 @@ int main(const int argc, char **argv) return 1; } - benchmark::MemoryPoller mp(std::chrono::milliseconds(5), args.getGpuMemoryPoll()); - std::vector mp_results; + std::unique_ptr mp{nullptr}; + if (args.getMemoryPoll()) + { + try + { + mp.reset(new benchmark::MemoryPoller(std::chrono::milliseconds(5), args.getGpuMemoryPoll())); + } + catch (const std::runtime_error &error) + { + std::cerr << error.what() << std::endl; + return 1; + } + } + std::vector mp_results({0, 0}); std::shared_ptr sess; @@ -135,9 +147,11 @@ int main(const int argc, char **argv) sess = std::make_shared(interpreter.get()); } - mp.Start("Compiling"); + if (mp) + mp->Start("Compiling"); sess->prepare(); - mp_results.emplace_back(mp.End("Compiling")); + if (mp) + mp_results[0] = mp->End("Compiling"); if (args.getInputShapes().size() != 0) { @@ -268,12 +282,14 @@ int main(const int argc, char **argv) std::cout << "]" << std::endl; // poll memories before warming up - mp.Start("Executing"); + if (mp) + mp->Start("Executing"); if (!sess->run()) { assert(0 && "run failed!"); } - mp_results.emplace_back(mp.End("Executing")); + if (mp) + mp_results[1] = mp->End("Executing"); // warmup runs for (uint32_t i = 1; i < args.getWarmupRuns(); i++) @@ -330,40 +346,46 @@ int main(const int argc, char **argv) std::cout << "- Max : " << acc.max() / 1e3 << "ms" << std::endl; std::cout << "- Mean: " << acc.mean() / 1e3 << "ms" << std::endl; - assert(mp_results.size() == 2); - std::cout << "===================================" << std::endl; - std::cout << "session_prepare takes " << mp_results[0] << " kb" << std::endl; - std::cout << "session_run takes " << mp_results[1] << " kb" << std::endl; - std::cout << "===================================" << std::endl; + if (mp) + { + assert(mp_results.size() == 2); + std::cout << "===================================" << std::endl; + std::cout << "session_prepare takes " << mp_results[0] << " kb" << std::endl; + std::cout << "session_run takes " << mp_results[1] << " kb" << std::endl; + std::cout << "===================================" << std::endl; + } } - // prepare csv task - std::string csv_filename; - std::string model_name; - std::string backend_name = default_backend_cand; + if (args.getWriteReport()) { - namespace fs = boost::filesystem; + // prepare csv task + std::string csv_filename; + std::string model_name; + std::string backend_name = default_backend_cand; + { + namespace fs = boost::filesystem; - fs::path model_path(args.getTFLiteFilename()); - model_name = model_path.stem().string(); + fs::path model_path(args.getTFLiteFilename()); + model_name = model_path.stem().string(); - fs::path exec_path(argv[0]); - std::string exec_name = exec_path.stem().string(); + fs::path exec_path(argv[0]); + std::string exec_name = exec_path.stem().string(); - csv_filename = exec_name + "-" + model_name + "-" + backend_name + ".csv"; - } + csv_filename = exec_name + "-" + model_name + "-" + backend_name + ".csv"; + } - // to csv - benchmark::CsvWriter writer(csv_filename); - writer << model_name << backend_name << acc.min() / 1e3 << acc.max() / 1e3 << acc.mean() / 1e3 - << mp_results[0] << mp_results[1]; - bool done = writer.Done(); + // to csv + benchmark::CsvWriter writer(csv_filename); + writer << model_name << backend_name << acc.min() / 1e3 << acc.max() / 1e3 << acc.mean() / 1e3 + << mp_results[0] << mp_results[1]; + bool done = writer.Done(); - std::cout << "Writing to " << csv_filename << " is "; - if (done) - std::cout << "done" << std::endl; - else - std::cout << "failed" << std::endl; + std::cout << "Writing to " << csv_filename << " is "; + if (done) + std::cout << "done" << std::endl; + else + std::cout << "failed" << std::endl; + } if (!args.getDumpFilename().empty()) { -- 2.7.4