#include "util/tensor/Index.h"
#include "util/tensor/Diff.h"
#include "util/tensor/Shape.h"
+#include "util/tensor/Comparator.h"
#include "support/tflite/TensorView.h"
#include <functional>
#include <vector>
-// NOTE The code below is subject to change.
-// TODO Introduce namespaces
-struct TfLiteTensorDiff : public nnfw::util::tensor::Diff<float>
-{
- TfLiteTensorDiff(const nnfw::util::tensor::Index &i) : nnfw::util::tensor::Diff<float>{i}
- {
- // DO NOTHING
- }
-};
-
-class TfLiteTensorComparator
-{
-public:
- TfLiteTensorComparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn}
- {
- // DO NOTHING
- }
-
-public:
- struct Observer
- {
- virtual void notify(const nnfw::util::tensor::Index &index, float expected, float obtained) = 0;
- };
-
-public:
- // NOTE Observer should live longer than comparator
- std::vector<TfLiteTensorDiff> compare(const nnfw::support::tflite::TensorView<float> &expected,
- const nnfw::support::tflite::TensorView<float> &obtained,
- Observer *observer = nullptr) const;
-
-private:
- std::function<bool (float lhs, float rhs)> _compare_fn;
-};
-
class TfLiteInterpMatchApp
{
public:
- TfLiteInterpMatchApp(const TfLiteTensorComparator &comparator)
+ TfLiteInterpMatchApp(const nnfw::util::tensor::Comparator &comparator)
: _verbose{false}, _comparator(comparator)
{
// DO NOTHING
int id) const;
private:
- const TfLiteTensorComparator &_comparator;
+ const nnfw::util::tensor::Comparator &_comparator;
};
#include "support/tflite/interp/Builder.h"
--- /dev/null
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_UTIL_TENSOR_COMPARATOR_H__
+#define __NNFW_UTIL_TENSOR_COMPARATOR_H__
+
+#include "util/tensor/Index.h"
+#include "util/tensor/Shape.h"
+#include "util/tensor/Reader.h"
+#include "util/tensor/Diff.h"
+
+#include <functional>
+
+#include <vector>
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+class Comparator
+{
+public:
+ Comparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn}
+ {
+ // DO NOTHING
+ }
+
+public:
+ struct Observer
+ {
+ virtual void notify(const Index &index, float expected, float obtained) = 0;
+ };
+
+public:
+ // NOTE Observer should live longer than comparator
+ std::vector<Diff<float>> compare(const Shape &shape,
+ const Reader<float> &expected,
+ const Reader<float> &obtained,
+ Observer *observer = nullptr) const;
+
+private:
+ std::function<bool (float lhs, float rhs)> _compare_fn;
+};
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw
+
+#endif // __NNFW_UTIL_TENSOR_COMPARATOR_H__
#include "util/tensor/IndexIterator.h"
#include "util/tensor/IndexFormatter.h"
#include "util/tensor/Zipper.h"
+#include "util/tensor/Comparator.h"
#include <iostream>
#include <cassert>
-class DiffSummary : public TfLiteTensorComparator::Observer
+class DiffSummary : public nnfw::util::tensor::Comparator::Observer
{
public:
DiffSummary()
}
}
-std::vector<TfLiteTensorDiff>
-TfLiteTensorComparator::compare(const nnfw::support::tflite::TensorView<float> &expected,
- const nnfw::support::tflite::TensorView<float> &obtained,
- Observer *observer) const
-{
- std::vector<TfLiteTensorDiff> res;
-
- assert(expected.shape() == obtained.shape());
-
- nnfw::util::tensor::zip(expected.shape(), expected, obtained) <<
- [&](const nnfw::util::tensor::Index &index, float expected_value, float obtained_value) {
- const auto relative_diff = nnfw::util::fp32::relative_diff(expected_value, obtained_value);
-
- if (!_compare_fn(expected_value, obtained_value))
- {
- TfLiteTensorDiff diff(index);
-
- diff.expected = expected_value;
- diff.obtained = obtained_value;
-
- res.emplace_back(diff);
- }
-
- // Update max_diff_index, if necessary
- if (observer != nullptr)
- {
- observer->notify(index, expected_value, obtained_value);
- }
- };
-
- return res;
-}
-
bool TfLiteInterpMatchApp::compareSingleTensorView(
const nnfw::support::tflite::TensorView<float> &expected,
const nnfw::support::tflite::TensorView<float> &obtained, int id) const
{
DiffSummary summary;
- auto diffs = _comparator.compare(expected, obtained, &summary);
+ assert(expected.shape() == obtained.shape());
+ auto diffs = _comparator.compare(expected.shape(), expected, obtained, &summary);
if (diffs.size() == 0)
{
return nnfw::util::fp32::epsilon_equal(lhs, rhs, tolerance);
};
- TfLiteTensorComparator comparator(equals);
+ nnfw::util::tensor::Comparator comparator(equals);
TfLiteInterpMatchApp app(comparator);
app.verbose() = _param.verbose;
list(APPEND NNFW_UTILITY_SRCS src/tensor/Shape.cpp)
list(APPEND NNFW_UTILITY_SRCS src/tensor/NonIncreasingStride.cpp)
list(APPEND NNFW_UTILITY_SRCS src/tensor/IndexFormatter.cpp)
+list(APPEND NNFW_UTILITY_SRCS src/tensor/Comparator.cpp)
add_library(nnfw_util SHARED ${NNFW_UTILITY_SRCS})
target_include_directories(nnfw_util PUBLIC ${NNFW_INCLUDE_DIR})
--- /dev/null
+#include "util/tensor/Comparator.h"
+#include "util/tensor/Zipper.h"
+
+#include "util/fp32.h"
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+std::vector<Diff<float>> Comparator::compare(const Shape &shape, const Reader<float> &expected,
+ const Reader<float> &obtained,
+ Observer *observer) const
+{
+ std::vector<Diff<float>> res;
+
+ zip(shape, expected, obtained) <<
+ [&](const Index &index, float expected_value, float obtained_value) {
+ const auto relative_diff = nnfw::util::fp32::relative_diff(expected_value, obtained_value);
+
+ if (!_compare_fn(expected_value, obtained_value))
+ {
+ Diff<float> diff(index);
+
+ diff.expected = expected_value;
+ diff.obtained = obtained_value;
+
+ res.emplace_back(diff);
+ }
+
+ // Update max_diff_index, if necessary
+ if (observer != nullptr)
+ {
+ observer->notify(index, expected_value, obtained_value);
+ }
+ };
+
+ return res;
+}
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw
return nnfw::util::fp32::epsilon_equal(lhs, rhs, tolerance);
};
- TfLiteTensorComparator comparator(equals);
+ nnfw::util::tensor::Comparator comparator(equals);
TfLiteInterpMatchApp app(comparator);
bool res = true;