From 4efd593626487f939b2296cabf5aed53ffc8d01a Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Mon, 9 Apr 2018 09:30:41 +0900 Subject: [PATCH] Introduce nnfw::util::tensor::Zipper (#495) This commit introduces Zipper class and related helper methods. Signed-off-by: Jonghyun Park --- include/util/tensor/Zipper.h | 56 ++++++++++++++++++++++++++++++++++++++ tools/nnapi_test/src/nnapi_test.cc | 6 ++-- 2 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 include/util/tensor/Zipper.h diff --git a/include/util/tensor/Zipper.h b/include/util/tensor/Zipper.h new file mode 100644 index 0000000..d0e0c46 --- /dev/null +++ b/include/util/tensor/Zipper.h @@ -0,0 +1,56 @@ +#ifndef __NNFW_UTIL_TENSOR_ZIPPER_H__ +#define __NNFW_UTIL_TENSOR_ZIPPER_H__ + +#include "util/tensor/Index.h" +#include "util/tensor/IndexIterator.h" +#include "util/tensor/Reader.h" + +namespace nnfw +{ +namespace util +{ +namespace tensor +{ + +template class Zipper +{ +public: + Zipper(const Shape &shape, const Reader &lhs, const Reader &rhs) + : _shape{shape}, _lhs{lhs}, _rhs{rhs} + { + // DO NOTHING + } + +public: + template void zip(Callable cb) const + { + iterate(_shape) << [this, &cb] (const Index &index) + { + cb(index, _lhs.at(index), _rhs.at(index)); + }; + } + +private: + const Shape &_shape; + const Reader &_lhs; + const Reader &_rhs; +}; + +template +const Zipper &operator<<(const Zipper &zipper, Callable cb) +{ + zipper.zip(cb); + return zipper; +} + +template +Zipper zip(const Shape &shape, const Reader &lhs, const Reader &rhs) +{ + return Zipper{shape, lhs, rhs}; +} + +} // namespace tensor +} // namespace util +} // namespace nnfw + +#endif // __NNFW_UTIL_TENSOR_ZIPPER_H__ diff --git a/tools/nnapi_test/src/nnapi_test.cc b/tools/nnapi_test/src/nnapi_test.cc index ae8ed05..c5c7d9d 100644 --- a/tools/nnapi_test/src/nnapi_test.cc +++ b/tools/nnapi_test/src/nnapi_test.cc @@ -4,6 +4,7 @@ #include "util/environment.h" #include "util/fp32.h" #include "util/tensor/IndexIterator.h" +#include "util/tensor/Zipper.h" #include "support/tflite/TensorView.h" #include "support/tflite/interp/FlatBufferBuilder.h" @@ -117,10 +118,9 @@ TfLiteTensorComparator::compare(const nnfw::support::tflite::TensorView & assert(expected.shape() == obtained.shape()); - nnfw::util::tensor::iterate(expected.shape()) << [&] (const nnfw::util::tensor::Index &index) + nnfw::util::tensor::zip(expected.shape(), expected, obtained) << + [&] (const nnfw::util::tensor::Index &index, float expected_value, float obtained_value) { - const auto expected_value = expected.at(index); - const auto obtained_value = obtained.at(index); const auto relative_diff = nnfw::util::fp32::relative_diff(expected_value, obtained_value); if (!_compare_fn(expected_value, obtained_value)) -- 2.7.4