From 3f1e1042f7e08b87088363aa16ea47d7c5bcb0ef Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Fri, 4 May 2018 15:19:54 +0900 Subject: [PATCH] [tflite_run] Support compare results (#1103) With `--compare` or `-c` option, user can compare the results with output tensor file which was dumped by `--dump` option. Signed-off-by: Hanjoung Lee --- tools/tflite_run/src/args.cc | 6 +++++ tools/tflite_run/src/args.h | 2 ++ tools/tflite_run/src/tflite_run.cc | 48 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/tools/tflite_run/src/args.cc b/tools/tflite_run/src/args.cc index 0147c75..0b6a58a 100644 --- a/tools/tflite_run/src/args.cc +++ b/tools/tflite_run/src/args.cc @@ -40,6 +40,7 @@ void Args::Initialize(void) ("help,h", "Display available options") ("input,i", po::value(&_input_filename)->default_value(""), "Input filename") ("dump,d", po::value()->default_value(""), "Output filename") + ("compare,c", po::value()->default_value(""), "filename to be compared with") ("tflite", po::value()->required()); // clang-format on @@ -82,6 +83,11 @@ void Args::Parse(const int argc, char **argv) _dump_filename = vm["dump"].as(); } + if (vm.count("compare")) + { + _compare_filename = vm["compare"].as(); + } + if (vm.count("tflite")) { _tflite_filename = vm["tflite"].as(); diff --git a/tools/tflite_run/src/args.h b/tools/tflite_run/src/args.h index 4142004..7b270d4 100644 --- a/tools/tflite_run/src/args.h +++ b/tools/tflite_run/src/args.h @@ -34,6 +34,7 @@ public: const std::string &getTFLiteFilename(void) const { return _tflite_filename; } const std::string &getInputFilename(void) const { return _input_filename; } const std::string &getDumpFilename(void) const { return _dump_filename; } + const std::string &getCompareFilename(void) const { return _compare_filename; } private: void Initialize(); @@ -46,6 +47,7 @@ private: std::string _tflite_filename; std::string _input_filename; std::string _dump_filename; + std::string _compare_filename; }; } // end of namespace TFLiteRun diff --git a/tools/tflite_run/src/tflite_run.cc b/tools/tflite_run/src/tflite_run.cc index 1f378c6..77fb365 100644 --- a/tools/tflite_run/src/tflite_run.cc +++ b/tools/tflite_run/src/tflite_run.cc @@ -21,6 +21,10 @@ #include "bin_image.h" #include "args.h" #include "output_tensor_dumper.h" +#include "output_tensor_loader.h" +#include "util/environment.h" +#include "util/fp32.h" +#include "support/tflite/Diff.h" #include #include @@ -193,5 +197,49 @@ int main(const int argc, char **argv) std::cout << "Prepare takes " << t_prepare.count() << " seconds" << std::endl; std::cout << "Invoke takes " << t_invoke.count() << " seconds" << std::endl; + if (!args.getCompareFilename().empty()) + { + const std::string &compare_filename = args.getCompareFilename(); + std::cout << "========================================" << std::endl; + std::cout << "Comparing the results with \"" << compare_filename << "\"." << std::endl; + std::cout << "========================================" << std::endl; + + TFLiteRun::OutputTensorLoader output_loader(*interpreter); + output_loader.load(compare_filename); + + // TODO Code duplication (copied from RandomTestRunner) + + int tolerance = 1; + nnfw::util::env::IntAccessor("TOLERANCE").access(tolerance); + + auto equals = [tolerance](float lhs, float rhs) { + // NOTE Hybrid approach + // TODO Allow users to set tolerance for absolute_epsilon_equal + if (nnfw::util::fp32::absolute_epsilon_equal(lhs, rhs)) + { + return true; + } + + return nnfw::util::fp32::epsilon_equal(lhs, rhs, tolerance); + }; + + TfLiteTensorComparator comparator(equals); + TfLiteInterpMatchApp app(comparator); + bool res = true; + + for (const auto &o : interpreter->outputs()) + { + auto expected = output_loader.get(o); + auto obtained = nnfw::support::tflite::TensorView::make(*interpreter, o); + + res = res && app.compareSingleTensorView(expected, obtained, o); + } + + if (!res) + { + return 255; + } + } + return status; } -- 2.7.4