[nncc.core] Add 'tensor::IndexRange' (#180)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 30 Apr 2018 08:35:25 +0000 (17:35 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 30 Apr 2018 08:35:25 +0000 (17:35 +0900)
This commit adds 'tensor::IndexRange' class which helps users to iterage
over all the index valid for a certain tensor shape.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
libs/core/include/nncc/core/ADT/tensor/IndexRange.h [new file with mode: 0644]
libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp [new file with mode: 0644]
libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp [new file with mode: 0644]

diff --git a/libs/core/include/nncc/core/ADT/tensor/IndexRange.h b/libs/core/include/nncc/core/ADT/tensor/IndexRange.h
new file mode 100644 (file)
index 0000000..f28f3d0
--- /dev/null
@@ -0,0 +1,38 @@
+#ifndef __NNCC_CORE_ADT_TENSOR_INDEX_RANGE_H__
+#define __NNCC_CORE_ADT_TENSOR_INDEX_RANGE_H__
+
+#include "nncc/core/ADT/tensor/Index.h"
+#include "nncc/core/ADT/tensor/Shape.h"
+
+#include <functional>
+
+namespace nncc
+{
+namespace core
+{
+namespace ADT
+{
+namespace tensor
+{
+
+class IndexRange
+{
+public:
+  explicit IndexRange(const Shape &shape);
+
+public:
+  bool member(const Index &index) const;
+
+public:
+  void iterate(const std::function<void(const Index &)> &) const;
+
+private:
+  const Shape _shape;
+};
+
+} // namespace tensor
+} // namespace ADT
+} // namespace core
+} // namespace nncc
+
+#endif // __NNCC_CORE_ADT_TENSOR_INDEX_RANGE_H__
diff --git a/libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp b/libs/core/src/nncc/core/ADT/tensor/IndexRange.cpp
new file mode 100644 (file)
index 0000000..5636c25
--- /dev/null
@@ -0,0 +1,92 @@
+#include "nncc/core/ADT/tensor/IndexRange.h"
+
+#include <cassert>
+
+namespace nncc
+{
+namespace core
+{
+namespace ADT
+{
+namespace tensor
+{
+
+IndexRange::IndexRange(const Shape &shape) : _shape(shape)
+{
+  // DO NOTHING
+}
+
+bool IndexRange::member(const nncc::core::ADT::tensor::Index &index) const
+{
+  if (index.rank() != _shape.rank())
+  {
+    return false;
+  }
+
+  const auto rank = _shape.rank();
+
+  for (uint32_t axis = 0; axis < rank; ++axis)
+  {
+    if (!(index.at(axis) < _shape.dim(axis)))
+    {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+void IndexRange::iterate(const std::function<void(const Index &)> &f) const
+{
+  const auto rank = _shape.rank();
+
+  Index index;
+
+  // Initialize index
+  index.resize(rank);
+  for (uint32_t axis = 0; axis < rank; ++axis)
+  {
+    index.at(axis) = 0;
+  }
+
+  if (!member(index))
+  {
+    // Nothing to iterate
+    return;
+  }
+
+  uint32_t cursor = 0;
+
+  while (cursor < rank)
+  {
+    f(index);
+
+    // Find axis to be updated
+    while ((cursor < rank) && !(index.at(cursor) + 1 < _shape.dim(cursor)))
+    {
+      ++cursor;
+    }
+
+    // Skip update if cursor is out of valid range
+    if (cursor == rank)
+    {
+      continue;
+    }
+
+    // Update index
+    index.at(cursor) += 1;
+
+    for (uint32_t axis = 0; axis < cursor; ++axis)
+    {
+      index.at(axis) = 0;
+    }
+
+    // Update cursor
+    cursor = 0;
+  }
+}
+
+} // namespace tensor
+} // namespace ADT
+} // namespace core
+} // namespace nncc
diff --git a/libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp b/libs/core/src/nncc/core/ADT/tensor/IndexRange.test.cpp
new file mode 100644 (file)
index 0000000..556d64c
--- /dev/null
@@ -0,0 +1,89 @@
+#include "nncc/core/ADT/tensor/IndexRange.h"
+
+#include <array>
+#include <algorithm>
+
+#include <gtest/gtest.h>
+
+TEST(ADT_TENSOR_INDEX_RANGE, member_positive)
+{
+  nncc::core::ADT::tensor::Shape shape;
+
+  shape.resize(1);
+  shape.dim(0) = 3;
+
+  nncc::core::ADT::tensor::Index index;
+
+  index.resize(1);
+  index.at(0) = 2;
+
+  ASSERT_TRUE(nncc::core::ADT::tensor::IndexRange{shape}.member(index));
+}
+
+TEST(ADT_TENSOR_INDEX_RANGE, member_negative_unmatched_rank)
+{
+  nncc::core::ADT::tensor::Shape shape;
+
+  shape.resize(1);
+  shape.dim(0) = 3;
+
+  nncc::core::ADT::tensor::Index index;
+
+  index.resize(2);
+  index.at(0) = 2;
+  index.at(1) = 3;
+
+  ASSERT_FALSE(nncc::core::ADT::tensor::IndexRange{shape}.member(index));
+}
+
+TEST(ADT_TENSOR_INDEX_RANGE, member_negative_overflowed_dim)
+{
+  nncc::core::ADT::tensor::Shape shape;
+
+  shape.resize(1);
+  shape.dim(0) = 3;
+
+  nncc::core::ADT::tensor::Index index;
+
+  index.resize(1);
+  index.at(0) = 4;
+
+  ASSERT_FALSE(nncc::core::ADT::tensor::IndexRange{shape}.member(index));
+}
+
+TEST(ADT_TENSOR_INDEX_RANGE, iterate_empty_range)
+{
+  nncc::core::ADT::tensor::Shape shape;
+
+  shape.resize(1);
+  shape.dim(0) = 0;
+
+  uint32_t count = 0;
+
+  auto f = [&count](const nncc::core::ADT::tensor::Index &) { ++count; };
+
+  nncc::core::ADT::tensor::IndexRange{shape}.iterate(f);
+
+  ASSERT_EQ(count, 0);
+}
+
+TEST(ADT_TENSOR_INDEX_RANGE, iterate_full_range)
+{
+  nncc::core::ADT::tensor::Shape shape;
+
+  shape.resize(2);
+  shape.dim(0) = 3;
+  shape.dim(1) = 4;
+
+  std::array<uint32_t, 3 * 4> count;
+
+  count.fill(0);
+
+  nncc::core::ADT::tensor::IndexRange{shape}.iterate(
+      [&count](const nncc::core::ADT::tensor::Index &i) {
+        ASSERT_EQ(i.rank(), 2);
+        count.at(i.at(0) * 4 + i.at(1)) += 1;
+      });
+
+  ASSERT_TRUE(std::all_of(count.begin(), count.end(), [](uint32_t n) { return n == 1; }));
+}