Generic Tensor Comparator (#1462)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Thu, 31 May 2018 06:54:11 +0000 (15:54 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Thu, 31 May 2018 06:54:11 +0000 (15:54 +0900)
This commit revises TfLiteTensorComparator as generic tensor comparator,
and move it to nnfw_util.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
include/support/tflite/Diff.h
include/util/tensor/Comparator.h [new file with mode: 0644]
libs/support/tflite/src/Diff.cpp
libs/util/CMakeLists.txt
libs/util/src/tensor/Comparator.cpp [new file with mode: 0644]
tools/tflite_run/src/tflite_run.cc

index 26d141a..f17c549 100644 (file)
 #include "util/tensor/Index.h"
 #include "util/tensor/Diff.h"
 #include "util/tensor/Shape.h"
+#include "util/tensor/Comparator.h"
 
 #include "support/tflite/TensorView.h"
 
 #include <functional>
 #include <vector>
 
-// NOTE The code below is subject to change.
-// TODO Introduce namespaces
-struct TfLiteTensorDiff : public nnfw::util::tensor::Diff<float>
-{
-  TfLiteTensorDiff(const nnfw::util::tensor::Index &i) : nnfw::util::tensor::Diff<float>{i}
-  {
-    // DO NOTHING
-  }
-};
-
-class TfLiteTensorComparator
-{
-public:
-  TfLiteTensorComparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn}
-  {
-    // DO NOTHING
-  }
-
-public:
-  struct Observer
-  {
-    virtual void notify(const nnfw::util::tensor::Index &index, float expected, float obtained) = 0;
-  };
-
-public:
-  // NOTE Observer should live longer than comparator
-  std::vector<TfLiteTensorDiff> compare(const nnfw::support::tflite::TensorView<float> &expected,
-                                        const nnfw::support::tflite::TensorView<float> &obtained,
-                                        Observer *observer = nullptr) const;
-
-private:
-  std::function<bool (float lhs, float rhs)> _compare_fn;
-};
-
 class TfLiteInterpMatchApp
 {
 public:
-  TfLiteInterpMatchApp(const TfLiteTensorComparator &comparator)
+  TfLiteInterpMatchApp(const nnfw::util::tensor::Comparator &comparator)
     : _verbose{false}, _comparator(comparator)
   {
     // DO NOTHING
@@ -84,7 +51,7 @@ public:
                                int id) const;
 
 private:
-  const TfLiteTensorComparator &_comparator;
+  const nnfw::util::tensor::Comparator &_comparator;
 };
 
 #include "support/tflite/interp/Builder.h"
diff --git a/include/util/tensor/Comparator.h b/include/util/tensor/Comparator.h
new file mode 100644 (file)
index 0000000..f0ab4ab
--- /dev/null
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_UTIL_TENSOR_COMPARATOR_H__
+#define __NNFW_UTIL_TENSOR_COMPARATOR_H__
+
+#include "util/tensor/Index.h"
+#include "util/tensor/Shape.h"
+#include "util/tensor/Reader.h"
+#include "util/tensor/Diff.h"
+
+#include <functional>
+
+#include <vector>
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+class Comparator
+{
+public:
+  Comparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn}
+  {
+    // DO NOTHING
+  }
+
+public:
+  struct Observer
+  {
+    virtual void notify(const Index &index, float expected, float obtained) = 0;
+  };
+
+public:
+  // NOTE Observer should live longer than comparator
+  std::vector<Diff<float>> compare(const Shape &shape,
+                                        const Reader<float> &expected,
+                                        const Reader<float> &obtained,
+                                        Observer *observer = nullptr) const;
+
+private:
+  std::function<bool (float lhs, float rhs)> _compare_fn;
+};
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw
+
+#endif // __NNFW_UTIL_TENSOR_COMPARATOR_H__
index 4b994eb..8355fc3 100644 (file)
 #include "util/tensor/IndexIterator.h"
 #include "util/tensor/IndexFormatter.h"
 #include "util/tensor/Zipper.h"
+#include "util/tensor/Comparator.h"
 
 #include <iostream>
 #include <cassert>
 
-class DiffSummary : public TfLiteTensorComparator::Observer
+class DiffSummary : public nnfw::util::tensor::Comparator::Observer
 {
 public:
   DiffSummary()
@@ -74,46 +75,14 @@ void DiffSummary::notify(const nnfw::util::tensor::Index &index, float expected,
   }
 }
 
-std::vector<TfLiteTensorDiff>
-TfLiteTensorComparator::compare(const nnfw::support::tflite::TensorView<float> &expected,
-                                const nnfw::support::tflite::TensorView<float> &obtained,
-                                Observer *observer) const
-{
-  std::vector<TfLiteTensorDiff> res;
-
-  assert(expected.shape() == obtained.shape());
-
-  nnfw::util::tensor::zip(expected.shape(), expected, obtained) <<
-      [&](const nnfw::util::tensor::Index &index, float expected_value, float obtained_value) {
-        const auto relative_diff = nnfw::util::fp32::relative_diff(expected_value, obtained_value);
-
-        if (!_compare_fn(expected_value, obtained_value))
-        {
-          TfLiteTensorDiff diff(index);
-
-          diff.expected = expected_value;
-          diff.obtained = obtained_value;
-
-          res.emplace_back(diff);
-        }
-
-        // Update max_diff_index, if necessary
-        if (observer != nullptr)
-        {
-          observer->notify(index, expected_value, obtained_value);
-        }
-      };
-
-  return res;
-}
-
 bool TfLiteInterpMatchApp::compareSingleTensorView(
     const nnfw::support::tflite::TensorView<float> &expected,
     const nnfw::support::tflite::TensorView<float> &obtained, int id) const
 {
   DiffSummary summary;
 
-  auto diffs = _comparator.compare(expected, obtained, &summary);
+  assert(expected.shape() == obtained.shape());
+  auto diffs = _comparator.compare(expected.shape(), expected, obtained, &summary);
 
   if (diffs.size() == 0)
   {
@@ -271,7 +240,7 @@ int RandomTestRunner::run(const nnfw::support::tflite::interp::Builder &builder)
     return nnfw::util::fp32::epsilon_equal(lhs, rhs, tolerance);
   };
 
-  TfLiteTensorComparator comparator(equals);
+  nnfw::util::tensor::Comparator comparator(equals);
   TfLiteInterpMatchApp app(comparator);
 
   app.verbose() = _param.verbose;
index efd9829..52ed272 100644 (file)
@@ -3,6 +3,7 @@ set(NNFW_UTILITY_SRCS src/environment.cpp)
 list(APPEND NNFW_UTILITY_SRCS src/tensor/Shape.cpp)
 list(APPEND NNFW_UTILITY_SRCS src/tensor/NonIncreasingStride.cpp)
 list(APPEND NNFW_UTILITY_SRCS src/tensor/IndexFormatter.cpp)
+list(APPEND NNFW_UTILITY_SRCS src/tensor/Comparator.cpp)
 
 add_library(nnfw_util SHARED ${NNFW_UTILITY_SRCS})
 target_include_directories(nnfw_util PUBLIC ${NNFW_INCLUDE_DIR})
diff --git a/libs/util/src/tensor/Comparator.cpp b/libs/util/src/tensor/Comparator.cpp
new file mode 100644 (file)
index 0000000..3374bb7
--- /dev/null
@@ -0,0 +1,45 @@
+#include "util/tensor/Comparator.h"
+#include "util/tensor/Zipper.h"
+
+#include "util/fp32.h"
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+std::vector<Diff<float>> Comparator::compare(const Shape &shape, const Reader<float> &expected,
+                                             const Reader<float> &obtained,
+                                             Observer *observer) const
+{
+  std::vector<Diff<float>> res;
+
+  zip(shape, expected, obtained) <<
+      [&](const Index &index, float expected_value, float obtained_value) {
+        const auto relative_diff = nnfw::util::fp32::relative_diff(expected_value, obtained_value);
+
+        if (!_compare_fn(expected_value, obtained_value))
+        {
+          Diff<float> diff(index);
+
+          diff.expected = expected_value;
+          diff.obtained = obtained_value;
+
+          res.emplace_back(diff);
+        }
+
+        // Update max_diff_index, if necessary
+        if (observer != nullptr)
+        {
+          observer->notify(index, expected_value, obtained_value);
+        }
+      };
+
+  return res;
+}
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw
index cff3d36..ede6eb4 100644 (file)
@@ -196,7 +196,7 @@ int main(const int argc, char **argv)
       return nnfw::util::fp32::epsilon_equal(lhs, rhs, tolerance);
     };
 
-    TfLiteTensorComparator comparator(equals);
+    nnfw::util::tensor::Comparator comparator(equals);
     TfLiteInterpMatchApp app(comparator);
     bool res = true;