From 5959fd1bde3e1b04c0197138b512460070fb840d Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Mon, 9 Apr 2018 11:28:58 +0900 Subject: [PATCH] Introduce tensor::NonIncreasingStride (#499) This commit extracts stride-related code from TensorView in nnfw::support::tflite as NonIncreasingStride utility class (in nnfw::util::tensor). Signed-off-by: Jonghyun Park --- include/support/tflite/TensorView.h | 6 ++--- include/util/tensor/NonIncreasingStride.h | 42 +++++++++++++++++++++++++++++ src/support/tflite/src/TensorView.cpp | 26 +++--------------- src/util/CMakeLists.txt | 1 + src/util/src/tensor/NonIncreasingStride.cpp | 30 +++++++++++++++++++++ 5 files changed, 78 insertions(+), 27 deletions(-) create mode 100644 include/util/tensor/NonIncreasingStride.h create mode 100644 src/util/src/tensor/NonIncreasingStride.cpp diff --git a/include/support/tflite/TensorView.h b/include/support/tflite/TensorView.h index 8aeb325..18de5df 100644 --- a/include/support/tflite/TensorView.h +++ b/include/support/tflite/TensorView.h @@ -6,6 +6,7 @@ #include "util/tensor/Shape.h" #include "util/tensor/Index.h" #include "util/tensor/Reader.h" +#include "util/tensor/NonIncreasingStride.h" namespace nnfw { @@ -29,14 +30,11 @@ public: float &at(const nnfw::util::tensor::Index &index); private: - uint32_t offsetOf(const nnfw::util::tensor::Index &index) const; - -private: nnfw::util::tensor::Shape _shape; public: float *_base; - std::vector _stride; + nnfw::util::tensor::NonIncreasingStride _stride; public: // TODO Introduce Operand ID class diff --git a/include/util/tensor/NonIncreasingStride.h b/include/util/tensor/NonIncreasingStride.h new file mode 100644 index 0000000..30a5c1e --- /dev/null +++ b/include/util/tensor/NonIncreasingStride.h @@ -0,0 +1,42 @@ +#ifndef __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__ +#define __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__ + +#include "util/tensor/Shape.h" +#include "util/tensor/Index.h" + +#include + +namespace nnfw +{ +namespace util +{ +namespace tensor +{ + +// As its name suggests, stride[N-1] >= stride[N] holds for all N < rank in NonIncreasingStride. +class NonIncreasingStride +{ +public: + void init(const Shape &shape) + { + _stride.resize(shape.rank()); + _stride.at(shape.rank() - 1) = 1; + + for (uint32_t axis = shape.rank() - 1; axis > 0; --axis) + { + _stride.at(axis - 1) = _stride.at(axis) * shape.dim(axis); + } + } + +public: + uint32_t offset(const Index &index) const; + +private: + std::vector _stride; +}; + +} // namespace tensor +} // namespace util +} // namespace nnfw + +#endif // __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__ diff --git a/src/support/tflite/src/TensorView.cpp b/src/support/tflite/src/TensorView.cpp index 18acfe6..774d0cf 100644 --- a/src/support/tflite/src/TensorView.cpp +++ b/src/support/tflite/src/TensorView.cpp @@ -12,39 +12,19 @@ namespace tflite TensorView::TensorView(const nnfw::util::tensor::Shape &shape, float *base) : _shape{shape}, _base{base} { // Set 'stride' - _stride.resize(_shape.rank()); - _stride.at(_shape.rank() - 1) = 1; - - for (uint32_t axis = _shape.rank() - 1; axis > 0; --axis) - { - _stride.at(axis - 1) = _stride.at(axis) * _shape.dim(axis); - } -} - -uint32_t TensorView::offsetOf(const nnfw::util::tensor::Index &index) const -{ - assert(index.rank() == _shape.rank()); - - uint32_t offset = 0; - - for (size_t axis = 0; axis < _shape.rank(); ++axis) - { - offset += _stride.at(axis) * index.at(axis); - } - - return offset; + _stride.init(_shape); } float TensorView::at(const nnfw::util::tensor::Index &index) const { - const auto offset = offsetOf(index); + const auto offset = _stride.offset(index); return *(_base + offset); } float &TensorView::at(const nnfw::util::tensor::Index &index) { - const auto offset = offsetOf(index); + const auto offset = _stride.offset(index); return *(_base + offset); } diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 97ab891..046f01d 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -1,6 +1,7 @@ # Library `nnfw_util` set(NNFW_UTILITY_SRCS src/environment.cpp) list(APPEND NNFW_UTILITY_SRCS src/tensor/Shape.cpp) +list(APPEND NNFW_UTILITY_SRCS src/tensor/NonIncreasingStride.cpp) set(NNFW_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include) diff --git a/src/util/src/tensor/NonIncreasingStride.cpp b/src/util/src/tensor/NonIncreasingStride.cpp new file mode 100644 index 0000000..54b6f41 --- /dev/null +++ b/src/util/src/tensor/NonIncreasingStride.cpp @@ -0,0 +1,30 @@ +#include "util/tensor/NonIncreasingStride.h" + +#include + +namespace nnfw +{ +namespace util +{ +namespace tensor +{ + +uint32_t NonIncreasingStride::offset(const Index &index) const +{ + const size_t rank = _stride.size(); + + assert(index.rank() == rank); + + uint32_t offset = 0; + + for (size_t axis = 0; axis < rank; ++axis) + { + offset += _stride.at(axis) * index.at(axis); + } + + return offset; +} + +} // namespace tensor +} // namespace util +} // namespace nnfw -- 2.7.4