From 0d4cf20dad647d2d9c4e1773f92a82f9b873b553 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Mon, 30 Apr 2018 17:35:25 +0900 Subject: [PATCH] [nncc.core] Add 'tensor::IndexRange' (#180) This commit adds 'tensor::IndexRange' class which helps users to iterage over all the index valid for a certain tensor shape. Signed-off-by: Jonghyun Park --- .../core/include/nncc/core/ADT/tensor/IndexRange.h | 38 +++++++++ libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp | 92 ++++++++++++++++++++++ .../src/nncc/core/ADT/tensor/IndexRange.test.cpp | 89 +++++++++++++++++++++ 3 files changed, 219 insertions(+) create mode 100644 libs/core/include/nncc/core/ADT/tensor/IndexRange.h create mode 100644 libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp create mode 100644 libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp diff --git a/libs/core/include/nncc/core/ADT/tensor/IndexRange.h b/libs/core/include/nncc/core/ADT/tensor/IndexRange.h new file mode 100644 index 0000000..f28f3d0 --- /dev/null +++ b/libs/core/include/nncc/core/ADT/tensor/IndexRange.h @@ -0,0 +1,38 @@ +#ifndef __NNCC_CORE_ADT_TENSOR_INDEX_RANGE_H__ +#define __NNCC_CORE_ADT_TENSOR_INDEX_RANGE_H__ + +#include "nncc/core/ADT/tensor/Index.h" +#include "nncc/core/ADT/tensor/Shape.h" + +#include + +namespace nncc +{ +namespace core +{ +namespace ADT +{ +namespace tensor +{ + +class IndexRange +{ +public: + explicit IndexRange(const Shape &shape); + +public: + bool member(const Index &index) const; + +public: + void iterate(const std::function &) const; + +private: + const Shape _shape; +}; + +} // namespace tensor +} // namespace ADT +} // namespace core +} // namespace nncc + +#endif // __NNCC_CORE_ADT_TENSOR_INDEX_RANGE_H__ diff --git a/libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp b/libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp new file mode 100644 index 0000000..5636c25 --- /dev/null +++ b/libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp @@ -0,0 +1,92 @@ +#include "nncc/core/ADT/tensor/IndexRange.h" + +#include + +namespace nncc +{ +namespace core +{ +namespace ADT +{ +namespace tensor +{ + +IndexRange::IndexRange(const Shape &shape) : _shape(shape) +{ + // DO NOTHING +} + +bool IndexRange::member(const nncc::core::ADT::tensor::Index &index) const +{ + if (index.rank() != _shape.rank()) + { + return false; + } + + const auto rank = _shape.rank(); + + for (uint32_t axis = 0; axis < rank; ++axis) + { + if (!(index.at(axis) < _shape.dim(axis))) + { + return false; + } + } + + return true; +} + +void IndexRange::iterate(const std::function &f) const +{ + const auto rank = _shape.rank(); + + Index index; + + // Initialize index + index.resize(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + { + index.at(axis) = 0; + } + + if (!member(index)) + { + // Nothing to iterate + return; + } + + uint32_t cursor = 0; + + while (cursor < rank) + { + f(index); + + // Find axis to be updated + while ((cursor < rank) && !(index.at(cursor) + 1 < _shape.dim(cursor))) + { + ++cursor; + } + + // Skip update if cursor is out of valid range + if (cursor == rank) + { + continue; + } + + // Update index + index.at(cursor) += 1; + + for (uint32_t axis = 0; axis < cursor; ++axis) + { + index.at(axis) = 0; + } + + // Update cursor + cursor = 0; + } +} + +} // namespace tensor +} // namespace ADT +} // namespace core +} // namespace nncc diff --git a/libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp b/libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp new file mode 100644 index 0000000..556d64c --- /dev/null +++ b/libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp @@ -0,0 +1,89 @@ +#include "nncc/core/ADT/tensor/IndexRange.h" + +#include +#include + +#include + +TEST(ADT_TENSOR_INDEX_RANGE, member_positive) +{ + nncc::core::ADT::tensor::Shape shape; + + shape.resize(1); + shape.dim(0) = 3; + + nncc::core::ADT::tensor::Index index; + + index.resize(1); + index.at(0) = 2; + + ASSERT_TRUE(nncc::core::ADT::tensor::IndexRange{shape}.member(index)); +} + +TEST(ADT_TENSOR_INDEX_RANGE, member_negative_unmatched_rank) +{ + nncc::core::ADT::tensor::Shape shape; + + shape.resize(1); + shape.dim(0) = 3; + + nncc::core::ADT::tensor::Index index; + + index.resize(2); + index.at(0) = 2; + index.at(1) = 3; + + ASSERT_FALSE(nncc::core::ADT::tensor::IndexRange{shape}.member(index)); +} + +TEST(ADT_TENSOR_INDEX_RANGE, member_negative_overflowed_dim) +{ + nncc::core::ADT::tensor::Shape shape; + + shape.resize(1); + shape.dim(0) = 3; + + nncc::core::ADT::tensor::Index index; + + index.resize(1); + index.at(0) = 4; + + ASSERT_FALSE(nncc::core::ADT::tensor::IndexRange{shape}.member(index)); +} + +TEST(ADT_TENSOR_INDEX_RANGE, iterate_empty_range) +{ + nncc::core::ADT::tensor::Shape shape; + + shape.resize(1); + shape.dim(0) = 0; + + uint32_t count = 0; + + auto f = [&count](const nncc::core::ADT::tensor::Index &) { ++count; }; + + nncc::core::ADT::tensor::IndexRange{shape}.iterate(f); + + ASSERT_EQ(count, 0); +} + +TEST(ADT_TENSOR_INDEX_RANGE, iterate_full_range) +{ + nncc::core::ADT::tensor::Shape shape; + + shape.resize(2); + shape.dim(0) = 3; + shape.dim(1) = 4; + + std::array count; + + count.fill(0); + + nncc::core::ADT::tensor::IndexRange{shape}.iterate( + [&count](const nncc::core::ADT::tensor::Index &i) { + ASSERT_EQ(i.rank(), 2); + count.at(i.at(0) * 4 + i.at(1)) += 1; + }); + + ASSERT_TRUE(std::all_of(count.begin(), count.end(), [](uint32_t n) { return n == 1; })); +} -- 2.7.4