Introduce TensorWrapper for unittesting (#619)
author김정현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh0822.kim@samsung.com>
Thu, 12 Apr 2018 08:03:17 +0000 (17:03 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 12 Apr 2018 08:03:17 +0000 (17:03 +0900)
This commit introduces TensorWrapper for unittesting.
This class can be used various test cases in `src/kernel`.

Signed-off-by: Junghyun Kim <jh0822.kim@samsung.com>
src/kernel/acl/src/util.h

index d5d0d41..74d4558 100644 (file)
+#ifndef __NNFW_KERNEL_ACL_UTIL_H__
+#define __NNFW_KERNEL_ACL_UTIL_H__
 #include <OperationsUtils.h>
 
+#include <cassert>
+#include <functional>
+
 namespace nnfw {
 namespace kernel {
 namespace acl {
 namespace util {
 
+// TODO: make a separate module.
+class TensorWrapper {
+public:
+  TensorWrapper(std::vector<uint32_t> dims,
+      OperandType type = OperandType::FLOAT32,
+      float scale = 1.0,
+      int32_t offset = 0)
+  :_shape{type, dims, scale, offset}
+  {
+
+    // currently, we support only FLOAT32 for now.
+    assert( type == OperandType::FLOAT32);
+
+    uint32_t size_bytes = sizeof(float);
+
+    _num_elems = 1;
+    for( auto& d: dims ) {
+      _num_elems *= d;
+    }
+
+    _data = new uint8_t[_num_elems * size_bytes];
+  }
+
+  ~TensorWrapper() {
+    delete [] _data;
+  }
+
+  android::nn::Shape shape() const {
+    return _shape;
+  }
+
+  uint32_t num_elems() const { return _num_elems; }
+
+  template<class T>
+  T at(const uint32_t& idx) const {
+    return reinterpret_cast<T*>(_data)[idx];
+  }
+
+  template<class T>
+  T& at(const uint32_t& idx) {
+    return reinterpret_cast<T*>(_data)[idx];
+  }
+
+  template<class T>
+  T* ptr() { return reinterpret_cast<T*>(_data); }
+
+  void initValue(float f) {
+    for( uint32_t i = 0; i < _num_elems; ++i ) {
+      at<float>(i) = f;
+    }
+  }
+
+  void initValue(std::function<float(uint32_t)> f) {
+    for( uint32_t i = 0; i < _num_elems; ++i ) {
+      at<float>(i) = f(i);
+    }
+  }
+
+  bool operator==(const TensorWrapper &t) const {
+    // compare the shape
+    assert(num_elems() == t.num_elems());
+    assert(_shape.type == t.shape().type);
+    assert(_shape.scale == t.shape().scale);
+    assert(_shape.offset == t.shape().offset);
+    assert(_shape.dimensions == t.shape().dimensions);
+
+    // currently, we support only FLOAT32.
+    assert(_shape.type == OperandType::FLOAT32);
+
+    for( uint32_t i = 0; i < _num_elems; ++i ) {
+      if( abs(at<float>(i) - t.at<float>(i)) > 0.001f ) {
+        std::cout << "Comparing [" << i << "] " << at<float>(i) << "," << t.at<float>(i) << std::endl;
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+private:
+  android::nn::Shape _shape;
+  uint32_t           _num_elems;
+  uint8_t*           _data;
+};
+
 void initData(float* data, int num, float value);
 bool compareData(const float* result, const float* expected, const android::nn::Shape& shape);
 void initData_Increasing(float* data, int num, float value);
@@ -13,3 +103,5 @@ void initData_Increasing(float* data, int num, float value);
 } // namespace acl
 } // namespace kernel
 } // namespace nnfw
+
+#endif // __NNFW_KERNEL_ACL_UTIL_H__