[nnc core] Add ShapeRange class (#394)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@partner.samsung.com>
Mon, 2 Jul 2018 15:02:51 +0000 (19:02 +0400)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Mon, 2 Jul 2018 15:02:51 +0000 (00:02 +0900)
[nnc core] Add ShapeRange class

Adds ShapeRange and ShapeIter classes used to iterate over given shape

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/libs/core/include/nnc/core/linalg/ShapeRange.h [new file with mode: 0644]
contrib/nnc/libs/core/src/core/linalg/ShapeRange.test.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/core/include/nnc/core/linalg/ShapeRange.h b/contrib/nnc/libs/core/include/nnc/core/linalg/ShapeRange.h
new file mode 100644 (file)
index 0000000..0dad66e
--- /dev/null
@@ -0,0 +1,98 @@
+#ifndef _NNC_CORE_LINALG_SHAPE_RANGE_H_
+#define _NNC_CORE_LINALG_SHAPE_RANGE_H_
+
+#include "nncc/core/ADT/tensor/Shape.h"
+#include "nncc/core/ADT/tensor/Index.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace data
+{
+
+using nncc::core::ADT::tensor::Shape;
+using nncc::core::ADT::tensor::Index;
+
+class ShapeIter :
+    public std::iterator<std::forward_iterator_tag, Index, size_t, Index*, Index&> {
+ public:
+  ShapeIter& operator++() {
+    auto* pidx = &_index.at(0);
+    auto* pshape = &_shape.dim(0);
+    auto rank = _shape.rank();
+    int c = rank - 1;
+    pidx[c]++;
+    while( (pidx[c] > pshape[c] - 1) && (c > 0) ) {
+      pidx[c] = 0;
+      pidx[--c]++;
+    }
+    _pos++;
+    return *this;
+  }
+
+  ShapeIter operator++(int) {
+    ShapeIter it = *this;
+    ++*this;
+    return it;
+  }
+
+  const Index& operator*() const {
+    return _index;
+  }
+
+  bool operator!=(ShapeIter& iter) {
+    assert(iter._index.rank() == _index.rank());
+    assert(iter._shape == _shape);
+    return _pos != iter._pos;
+  }
+
+ private:
+  explicit ShapeIter(Shape &_shape, uint32_t pos) : _pos(pos), _shape(_shape) {
+    _index.resize(_shape.rank());
+    _index.fill(0);
+  }
+
+  uint32_t _pos;
+  Index _index;
+  Shape& _shape;
+
+  friend class ShapeRange;
+};
+
+class ShapeRange {
+ public:
+  explicit ShapeRange(const Shape &shape) : _shape(const_cast<Shape&>(shape))
+  {}
+
+  ShapeIter begin() {
+    return ShapeIter(_shape, 0);
+  }
+
+  ShapeIter end() {
+    uint32_t _end = 1;
+    for( uint32_t d = 0; d < _shape.rank(); ++d ) _end *= _shape.dim(d);
+    return ShapeIter(_shape, _end);
+  }
+
+  bool contains(const Index& idx) {
+    assert(idx.rank() == _shape.rank());
+    bool res = true;
+    for(uint32_t d = 0; d < idx.rank(); ++d ) {
+      res &= idx.at(d) < _shape.dim(d);
+    }
+    return res;
+  }
+
+ private:
+  Shape& _shape;
+};
+
+} // namespace data
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif //_NNC_CORE_LINALG_SHAPE_RANGE_H_
diff --git a/contrib/nnc/libs/core/src/core/linalg/ShapeRange.test.cpp b/contrib/nnc/libs/core/src/core/linalg/ShapeRange.test.cpp
new file mode 100644 (file)
index 0000000..99bce76
--- /dev/null
@@ -0,0 +1,44 @@
+#include "gtest/gtest.h"
+#include "nnc/core/linalg/ShapeRange.h"
+
+namespace {
+
+using namespace nncc::contrib::core::data;
+
+struct ParamType {
+  uint32_t actual_length;
+  Shape shape;
+
+  template<typename ...Args>
+  explicit ParamType(uint32_t actual_len, Args&& ...args) : actual_length(actual_len), shape({static_cast<uint32_t>(args)...}) {}
+};
+
+class ShapeIteratorTest : public ::testing::TestWithParam<ParamType> {
+
+};
+
+TEST_P(ShapeIteratorTest, ElementCount) {
+  Shape sh(GetParam().shape);
+  ShapeRange r(sh);
+
+  uint32_t cnt = 0;
+  for( auto& idx : r ) {
+    (void)idx;
+    cnt++;
+  }
+
+  ASSERT_EQ(cnt, GetParam().actual_length);
+}
+
+
+std::vector<ParamType> test_data {
+    ParamType{6,   1,2,3},
+    ParamType{16,  2,2,4},
+    ParamType{1,   1,1,1,1,1},
+    ParamType{5,   5,1,1,1,1,1}
+};
+
+
+INSTANTIATE_TEST_CASE_P(SimpleInput, ShapeIteratorTest, ::testing::ValuesIn(test_data));
+
+}