[tflite_run] Support compare results (#1103)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Fri, 4 May 2018 06:19:54 +0000 (15:19 +0900)
committer김정현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh0822.kim@samsung.com>
Fri, 4 May 2018 06:19:54 +0000 (15:19 +0900)
With `--compare` or `-c` option, user can compare the results with output
tensor file which was dumped by `--dump` option.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
tools/tflite_run/src/args.cc
tools/tflite_run/src/args.h
tools/tflite_run/src/tflite_run.cc

index 0147c75..0b6a58a 100644 (file)
@@ -40,6 +40,7 @@ void Args::Initialize(void)
     ("help,h", "Display available options")
     ("input,i", po::value<std::string>(&_input_filename)->default_value(""), "Input filename")
     ("dump,d", po::value<std::string>()->default_value(""), "Output filename")
+    ("compare,c", po::value<std::string>()->default_value(""), "filename to be compared with")
     ("tflite", po::value<std::string>()->required());
   // clang-format on
 
@@ -82,6 +83,11 @@ void Args::Parse(const int argc, char **argv)
     _dump_filename = vm["dump"].as<std::string>();
   }
 
+  if (vm.count("compare"))
+  {
+    _compare_filename = vm["compare"].as<std::string>();
+  }
+
   if (vm.count("tflite"))
   {
     _tflite_filename = vm["tflite"].as<std::string>();
index 4142004..7b270d4 100644 (file)
@@ -34,6 +34,7 @@ public:
   const std::string &getTFLiteFilename(void) const { return _tflite_filename; }
   const std::string &getInputFilename(void) const { return _input_filename; }
   const std::string &getDumpFilename(void) const { return _dump_filename; }
+  const std::string &getCompareFilename(void) const { return _compare_filename; }
 
 private:
   void Initialize();
@@ -46,6 +47,7 @@ private:
   std::string _tflite_filename;
   std::string _input_filename;
   std::string _dump_filename;
+  std::string _compare_filename;
 };
 
 } // end of namespace TFLiteRun
index 1f378c6..77fb365 100644 (file)
 #include "bin_image.h"
 #include "args.h"
 #include "output_tensor_dumper.h"
+#include "output_tensor_loader.h"
+#include "util/environment.h"
+#include "util/fp32.h"
+#include "support/tflite/Diff.h"
 
 #include <iostream>
 #include <chrono>
@@ -193,5 +197,49 @@ int main(const int argc, char **argv)
   std::cout << "Prepare takes " << t_prepare.count() << " seconds" << std::endl;
   std::cout << "Invoke takes " << t_invoke.count() << " seconds" << std::endl;
 
+  if (!args.getCompareFilename().empty())
+  {
+    const std::string &compare_filename = args.getCompareFilename();
+    std::cout << "========================================" << std::endl;
+    std::cout << "Comparing the results with \"" << compare_filename << "\"." << std::endl;
+    std::cout << "========================================" << std::endl;
+
+    TFLiteRun::OutputTensorLoader output_loader(*interpreter);
+    output_loader.load(compare_filename);
+
+    // TODO Code duplication (copied from RandomTestRunner)
+
+    int tolerance = 1;
+    nnfw::util::env::IntAccessor("TOLERANCE").access(tolerance);
+
+    auto equals = [tolerance](float lhs, float rhs) {
+      // NOTE Hybrid approach
+      // TODO Allow users to set tolerance for absolute_epsilon_equal
+      if (nnfw::util::fp32::absolute_epsilon_equal(lhs, rhs))
+      {
+        return true;
+      }
+
+      return nnfw::util::fp32::epsilon_equal(lhs, rhs, tolerance);
+    };
+
+    TfLiteTensorComparator comparator(equals);
+    TfLiteInterpMatchApp app(comparator);
+    bool res = true;
+
+    for (const auto &o : interpreter->outputs())
+    {
+      auto expected = output_loader.get(o);
+      auto obtained = nnfw::support::tflite::TensorView<float>::make(*interpreter, o);
+
+      res = res && app.compareSingleTensorView(expected, obtained, o);
+    }
+
+    if (!res)
+    {
+      return 255;
+    }
+  }
+
   return status;
 }