Introduce tensor::NonIncreasingStride (#499)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 9 Apr 2018 02:28:58 +0000 (11:28 +0900)
committer서상민/동작제어Lab(SR)/Senior Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 9 Apr 2018 02:28:58 +0000 (11:28 +0900)
This commit extracts stride-related code from TensorView in nnfw::support::tflite
as NonIncreasingStride utility class (in nnfw::util::tensor).

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
include/support/tflite/TensorView.h
include/util/tensor/NonIncreasingStride.h [new file with mode: 0644]
src/support/tflite/src/TensorView.cpp
src/util/CMakeLists.txt
src/util/src/tensor/NonIncreasingStride.cpp [new file with mode: 0644]

index 8aeb325..18de5df 100644 (file)
@@ -6,6 +6,7 @@
 #include "util/tensor/Shape.h"
 #include "util/tensor/Index.h"
 #include "util/tensor/Reader.h"
+#include "util/tensor/NonIncreasingStride.h"
 
 namespace nnfw
 {
@@ -29,14 +30,11 @@ public:
   float &at(const nnfw::util::tensor::Index &index);
 
 private:
-  uint32_t offsetOf(const nnfw::util::tensor::Index &index) const;
-
-private:
   nnfw::util::tensor::Shape _shape;
 
 public:
   float *_base;
-  std::vector<uint32_t> _stride;
+  nnfw::util::tensor::NonIncreasingStride _stride;
 
 public:
   // TODO Introduce Operand ID class
diff --git a/include/util/tensor/NonIncreasingStride.h b/include/util/tensor/NonIncreasingStride.h
new file mode 100644 (file)
index 0000000..30a5c1e
--- /dev/null
@@ -0,0 +1,42 @@
+#ifndef __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__
+#define __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__
+
+#include "util/tensor/Shape.h"
+#include "util/tensor/Index.h"
+
+#include <vector>
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+// As its name suggests, stride[N-1] >= stride[N] holds for all N < rank in NonIncreasingStride.
+class NonIncreasingStride
+{
+public:
+  void init(const Shape &shape)
+  {
+    _stride.resize(shape.rank());
+    _stride.at(shape.rank() - 1) = 1;
+
+    for (uint32_t axis = shape.rank() - 1; axis > 0; --axis)
+    {
+      _stride.at(axis - 1) = _stride.at(axis) * shape.dim(axis);
+    }
+  }
+
+public:
+  uint32_t offset(const Index &index) const;
+
+private:
+  std::vector<uint32_t> _stride;
+};
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw
+
+#endif // __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__
index 18acfe6..774d0cf 100644 (file)
@@ -12,39 +12,19 @@ namespace tflite
 TensorView<float>::TensorView(const nnfw::util::tensor::Shape &shape, float *base) : _shape{shape}, _base{base}
 {
   // Set 'stride'
-  _stride.resize(_shape.rank());
-  _stride.at(_shape.rank() - 1) = 1;
-
-  for (uint32_t axis = _shape.rank() - 1; axis > 0; --axis)
-  {
-    _stride.at(axis - 1) = _stride.at(axis) * _shape.dim(axis);
-  }
-}
-
-uint32_t TensorView<float>::offsetOf(const nnfw::util::tensor::Index &index) const
-{
-  assert(index.rank() == _shape.rank());
-
-  uint32_t offset = 0;
-
-  for (size_t axis = 0; axis < _shape.rank(); ++axis)
-  {
-    offset += _stride.at(axis) * index.at(axis);
-  }
-
-  return offset;
+  _stride.init(_shape);
 }
 
 float TensorView<float>::at(const nnfw::util::tensor::Index &index) const
 {
-  const auto offset = offsetOf(index);
+  const auto offset = _stride.offset(index);
 
   return *(_base + offset);
 }
 
 float &TensorView<float>::at(const nnfw::util::tensor::Index &index)
 {
-  const auto offset = offsetOf(index);
+  const auto offset = _stride.offset(index);
 
   return *(_base + offset);
 }
index 97ab891..046f01d 100644 (file)
@@ -1,6 +1,7 @@
 # Library `nnfw_util`
 set(NNFW_UTILITY_SRCS src/environment.cpp)
 list(APPEND NNFW_UTILITY_SRCS src/tensor/Shape.cpp)
+list(APPEND NNFW_UTILITY_SRCS src/tensor/NonIncreasingStride.cpp)
 
 set(NNFW_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include)
 
diff --git a/src/util/src/tensor/NonIncreasingStride.cpp b/src/util/src/tensor/NonIncreasingStride.cpp
new file mode 100644 (file)
index 0000000..54b6f41
--- /dev/null
@@ -0,0 +1,30 @@
+#include "util/tensor/NonIncreasingStride.h"
+
+#include <cassert>
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+uint32_t NonIncreasingStride::offset(const Index &index) const
+{
+  const size_t rank = _stride.size();
+
+  assert(index.rank() == rank);
+
+  uint32_t offset = 0;
+
+  for (size_t axis = 0; axis < rank; ++axis)
+  {
+    offset += _stride.at(axis) * index.at(axis);
+  }
+
+  return offset;
+}
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw