Support basic benchmark (tflite_run) (#52)
author Jonghyun Park / Motion Control Lab (SR) / Senior Engineer / Samsung Electronics <jh1302.park@samsung.com>
Fri, 16 Mar 2018 09:41:21 +0000 (18:41 +0900)
committer Sangmin Seo / Motion Control Lab (SR) / Senior Engineer / Samsung Electronics <sangmin7.seo@samsung.com>
Fri, 16 Mar 2018 09:41:21 +0000 (18:41 +0900)
This commit revises tflite_run to report the elapsed time of the prepare and
invoke steps.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
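
For reference, a minimal sketch of how the benchmark helpers added below are
intended to be used (the class and function names match the diff; the lambda
body and the label string are illustrative placeholders):

    benchmark::ElapsedTime elapsed;

    // The temporary Stopwatch returned by measure() times the lambda and
    // accumulates the wall-clock duration into 'elapsed'.
    elapsed.measure() << [&](void) {
      // ... work to be timed, e.g. interpreter->Invoke() ...
    };

    std::cout << "Work takes " << elapsed.count() << " seconds" << std::endl;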
tools/tflite_run/src/tflite_run.cc

index e4c78ee..6843edf 100644 (file)
@@ -4,10 +4,66 @@
 #include "operators.h"
 
 #include <iostream>
+#include <chrono>
 
 using namespace tflite;
 using namespace tflite::ops::builtin;
 
+// Benchmark support
+namespace benchmark
+{
+
+class ElapsedTime;
+
+class Stopwatch
+{
+public:
+  Stopwatch(ElapsedTime *buffer) : _buffer(buffer)
+  {
+    // DO NOTHING
+  }
+
+public:
+  ElapsedTime *buffer(void) { return _buffer; }
+
+private:
+  ElapsedTime *_buffer;
+};
+
+class ElapsedTime
+{
+public:
+  double count(void) const { return _elapsed.count(); }
+
+public:
+  ElapsedTime &add(const std::chrono::duration<double> &elapsed)
+  {
+    _elapsed += elapsed;
+    return (*this);
+  }
+
+public:
+  Stopwatch measure(void) { return Stopwatch(this); }
+
+private:
+  std::chrono::duration<double> _elapsed{0.0}; // start from zero, not an indeterminate value
+};
+
+template <typename Callable> Stopwatch &operator<<(Stopwatch &&sw, Callable cb)
+{
+  using namespace std::chrono;
+
+  auto begin = steady_clock::now();
+  cb();
+  auto end = steady_clock::now();
+
+  sw.buffer()->add(duration_cast<duration<double>>(end - begin));
+
+  return sw;
+}
+
+} // namespace benchmark
+
 int main(int argc, char **argv)
 {
   const auto filename = argv[1];
@@ -23,32 +79,44 @@ int main(int argc, char **argv)
 
   auto model = FlatBufferModel::BuildFromFile(filename, &error_reporter);
 
-  BuiltinOpResolver resolver;
+  std::unique_ptr<Interpreter> interpreter;
+
+  TfLiteStatus status = kTfLiteError;
+
+  benchmark::ElapsedTime t_prepare;
+  benchmark::ElapsedTime t_invoke;
+
+  t_prepare.measure() << [&](void)
+  {
+    BuiltinOpResolver resolver;
 
 #define REGISTER(Name) { resolver.AddCustom(#Name, Register_##Name()); }
-  REGISTER(CAST);
-  REGISTER(Stack);
-  REGISTER(ArgMax);
-  REGISTER(TensorFlowMax);
+    REGISTER(CAST);
+    REGISTER(Stack);
+    REGISTER(ArgMax);
+    REGISTER(TensorFlowMax);
 #undef REGISTER
 
-  InterpreterBuilder builder(*model, resolver);
+    InterpreterBuilder builder(*model, resolver);
 
-  std::unique_ptr<Interpreter> interpreter;
-
-  TfLiteStatus status = kTfLiteError;
+    status = builder(&interpreter);
+    assert(status == kTfLiteOk);
 
-  status = builder(&interpreter);
-  assert(status == kTfLiteOk);
+    interpreter->UseNNAPI(use_nnapi);
+    interpreter->SetNumThreads(1);
 
-  interpreter->UseNNAPI(use_nnapi);
-  interpreter->SetNumThreads(1);
+    status = interpreter->AllocateTensors();
+    assert(status == kTfLiteOk);
+  };
 
-  status = interpreter->AllocateTensors();
-  assert(status == kTfLiteOk);
+  t_invoke.measure() << [&status, &interpreter](void)
+  {
+    status = interpreter->Invoke();
+    assert(status == kTfLiteOk);
+  };
 
-  status = interpreter->Invoke();
-  assert(status == kTfLiteOk);
+  std::cout << "Prepare takes " << t_prepare.count() << " seconds" << std::endl;
+  std::cout << "Invoke takes " << t_invoke.count() << " seconds" << std::endl;
 
   return status;
 }