This commit extracts stride-related code from TensorView in nnfw::support::tflite
as NonIncreasingStride utility class (in nnfw::util::tensor).
Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
#include "util/tensor/Shape.h"
#include "util/tensor/Index.h"
#include "util/tensor/Reader.h"
+#include "util/tensor/NonIncreasingStride.h"
namespace nnfw
{
float &at(const nnfw::util::tensor::Index &index);
private:
- uint32_t offsetOf(const nnfw::util::tensor::Index &index) const;
-
-private:
nnfw::util::tensor::Shape _shape;
public:
float *_base;
- std::vector<uint32_t> _stride;
+ nnfw::util::tensor::NonIncreasingStride _stride;
public:
// TODO Introduce Operand ID class
--- /dev/null
+#ifndef __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__
+#define __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__
+
+#include "util/tensor/Shape.h"
+#include "util/tensor/Index.h"
+
+#include <vector>
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+// As its name suggests, stride[N-1] >= stride[N] holds for all N < rank in NonIncreasingStride.
+class NonIncreasingStride
+{
+public:
+ void init(const Shape &shape)
+ {
+ _stride.resize(shape.rank());
+ _stride.at(shape.rank() - 1) = 1;
+
+ for (uint32_t axis = shape.rank() - 1; axis > 0; --axis)
+ {
+ _stride.at(axis - 1) = _stride.at(axis) * shape.dim(axis);
+ }
+ }
+
+public:
+ uint32_t offset(const Index &index) const;
+
+private:
+ std::vector<uint32_t> _stride;
+};
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw
+
+#endif // __NNFW_UTIL_TENSOR_NON_INCREASING_STRIDE_H__
TensorView<float>::TensorView(const nnfw::util::tensor::Shape &shape, float *base) : _shape{shape}, _base{base}
{
// Set 'stride'
- _stride.resize(_shape.rank());
- _stride.at(_shape.rank() - 1) = 1;
-
- for (uint32_t axis = _shape.rank() - 1; axis > 0; --axis)
- {
- _stride.at(axis - 1) = _stride.at(axis) * _shape.dim(axis);
- }
-}
-
-uint32_t TensorView<float>::offsetOf(const nnfw::util::tensor::Index &index) const
-{
- assert(index.rank() == _shape.rank());
-
- uint32_t offset = 0;
-
- for (size_t axis = 0; axis < _shape.rank(); ++axis)
- {
- offset += _stride.at(axis) * index.at(axis);
- }
-
- return offset;
+ _stride.init(_shape);
}
float TensorView<float>::at(const nnfw::util::tensor::Index &index) const
{
- const auto offset = offsetOf(index);
+ const auto offset = _stride.offset(index);
return *(_base + offset);
}
float &TensorView<float>::at(const nnfw::util::tensor::Index &index)
{
- const auto offset = offsetOf(index);
+ const auto offset = _stride.offset(index);
return *(_base + offset);
}
# Library `nnfw_util`
set(NNFW_UTILITY_SRCS src/environment.cpp)
list(APPEND NNFW_UTILITY_SRCS src/tensor/Shape.cpp)
+list(APPEND NNFW_UTILITY_SRCS src/tensor/NonIncreasingStride.cpp)
set(NNFW_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include)
--- /dev/null
+#include "util/tensor/NonIncreasingStride.h"
+
+#include <cassert>
+
+namespace nnfw
+{
+namespace util
+{
+namespace tensor
+{
+
+uint32_t NonIncreasingStride::offset(const Index &index) const
+{
+ const size_t rank = _stride.size();
+
+ assert(index.rank() == rank);
+
+ uint32_t offset = 0;
+
+ for (size_t axis = 0; axis < rank; ++axis)
+ {
+ offset += _stride.at(axis) * index.at(axis);
+ }
+
+ return offset;
+}
+
+} // namespace tensor
+} // namespace util
+} // namespace nnfw