Introduce nnapi_test tool (#290)
author Jonghyun Park/Motion Control Lab (SR)/Senior Engineer/Samsung Electronics <jh1302.park@samsung.com>
Thu, 29 Mar 2018 03:12:27 +0000 (12:12 +0900)
committer Sangmin Seo/Motion Control Lab (SR)/Senior Engineer/Samsung Electronics <sangmin7.seo@samsung.com>
Thu, 29 Mar 2018 03:12:27 +0000 (12:12 +0900)
* Introduce nnapi_test tool

This commit introduces the nnapi_test tool, which runs the T/F Lite
interpreter with and without NNAPI, and compares their results.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
tools/CMakeLists.txt
tools/nnapi_test/CMakeLists.txt [new file with mode: 0644]
tools/nnapi_test/src/nnapi_test.cc [new file with mode: 0644]

index ea8370f..171b2ee 100644 (file)
@@ -4,3 +4,4 @@ if(ROOTFS_ARM STREQUAL "")
 endif()
 add_subdirectory(tflite_run)
 add_subdirectory(nnapi_bindings)
+add_subdirectory(nnapi_test)
diff --git a/tools/nnapi_test/CMakeLists.txt b/tools/nnapi_test/CMakeLists.txt
new file mode 100644 (file)
index 0000000..8382404
--- /dev/null
@@ -0,0 +1,4 @@
+list(APPEND SOURCES "src/nnapi_test.cc")
+
+add_executable(nnapi_test ${SOURCES})
+target_link_libraries(nnapi_test tensorflow_lite)
diff --git a/tools/nnapi_test/src/nnapi_test.cc b/tools/nnapi_test/src/nnapi_test.cc
new file mode 100644 (file)
index 0000000..fd3fd64
--- /dev/null
@@ -0,0 +1,140 @@
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+#include <iostream>
+#include <chrono>
+#include <algorithm>
+#include <cassert>
+
+using namespace tflite;
+using namespace tflite::ops::builtin;
+
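+// NOTE assert is compiled out when NDEBUG is defined, so this check is
+//      effective in debug builds only.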
+inline void check(const TfLiteStatus status) { assert(status != kTfLiteError); }
+
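+// Build a T/F Lite interpreter for the given model, with NNAPI delegation
+// enabled or disabled according to 'use_nnapi'.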
+std::unique_ptr<Interpreter> build_interpreter(const FlatBufferModel &model, bool use_nnapi)
+{
+  std::unique_ptr<Interpreter> interpreter;
+
+  BuiltinOpResolver resolver;
+
+  InterpreterBuilder builder(model, resolver);
+
+  check(builder(&interpreter));
+
+  interpreter->UseNNAPI(use_nnapi);
+
+  return interpreter;
+}
+
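+// Fill every input tensor with a deterministic byte pattern so that both
+// interpreter instances run on identical inputs.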
+void initialize_interpreter(Interpreter &interpreter)
+{
+  check(interpreter.AllocateTensors());
+
+  // TODO Find a better way to initialize tensors
+  for (const auto &id : interpreter.inputs())
+  {
+    auto tensor = interpreter.tensor(id);
+    auto ptr = tensor->data.uint8;
+    auto len = tensor->bytes;
+
+    for (size_t ind = 0; ind < len; ++ind)
+    {
+      ptr[ind] = ind;
+    }
+  }
+}
+
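+// Run inference; aborts (in debug builds) if invocation fails.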
+void invoke_interpreter(Interpreter &interpreter)
+{
+  check(interpreter.Invoke());
+}
+
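+// Element-wise comparison for the index vectors returned by Interpreter::inputs()/outputs().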
+template <typename T> bool operator==(const std::vector<T> &lhs, const std::vector<T> &rhs)
+{
+  if (lhs.size() != rhs.size())
+  {
+    return false;
+  }
+
+  for (size_t ind = 0; ind < lhs.size(); ++ind)
+  {
+    if (lhs.at(ind) != rhs.at(ind))
+    {
+      return false;
+    }
+  }
+
+  return true;
+}
+
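+// Two tensors are considered equal when their raw contents match byte-for-byte.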
+bool operator==(const TfLiteTensor &lhs, const TfLiteTensor &rhs)
+{
+  if (lhs.bytes != rhs.bytes)
+  {
+    return false;
+  }
+
+  for (size_t off = 0; off < lhs.bytes; ++off)
+  {
+    if (lhs.data.uint8[off] != rhs.data.uint8[off])
+    {
+      return false;
+    }
+  }
+
+  return true;
+}
+
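+// Run the given model twice (without and with NNAPI) and compare all output
+// tensors.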
+int main(const int argc, char **argv)
+{
+  if (argc < 2)
+  {
+    std::cerr << "Usage: " << argv[0] << " <tflite model path>" << std::endl;
+    return 1;
+  }
+
+  const auto filename = argv[1];
+
+  StderrReporter error_reporter;
+
+  auto model = FlatBufferModel::BuildFromFile(filename, &error_reporter);
+  assert(model != nullptr);
+
+  std::cout << "[NNAPI TEST] Run T/F Lite Interpreter without NNAPI" << std::endl;
+
+  std::unique_ptr<Interpreter> pure = build_interpreter(*model, false);
+  initialize_interpreter(*pure);
+  invoke_interpreter(*pure);
+
+  std::cout << "[NNAPI TEST] Run T/F Lite Interpreter with NNAPI" << std::endl;
+
+  std::unique_ptr<Interpreter> delegated = build_interpreter(*model, true);
+  initialize_interpreter(*delegated);
+  invoke_interpreter(*delegated);
+
+  std::cout << "[NNAPI TEST] Compare the result" << std::endl;
+
+  assert(pure->inputs() == delegated->inputs());
+  assert(pure->outputs() == delegated->outputs());
+
+  for (const auto &id : pure->outputs())
+  {
+    std::cout << "  Compare tensor #" << id << std::endl;
+    assert(*(pure->tensor(id)) == *(delegated->tensor(id)));
+  }
+
+  std::cout << "[NNAPI TEST] PASSED" << std::endl;
+
+  return 0;
+}