[Pure ACL NN Runtime] Implement setOperand functions (#490)
author박종현/동작제어Lab(SR)/Senior Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 9 Apr 2018 01:14:18 +0000 (10:14 +0900)
committer서상민/동작제어Lab(SR)/Senior Engineer/삼성전자 <sangmin7.seo@samsung.com>
Mon, 9 Apr 2018 01:14:18 +0000 (10:14 +0900)
This commit implements setOperandValue and  setOperandValueFromMemory
functions for pure ACL NN Runtime.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.cc
tools/nnapi_bindings/bindings/pure_arm_compute/src/internal/Model.h
tools/nnapi_bindings/bindings/pure_arm_compute/src/memory.h
tools/nnapi_bindings/bindings/pure_arm_compute/src/model.cc

index bd1cbb0..580c925 100644 (file)
@@ -37,6 +37,11 @@ const Object &Set::at(const Index &index) const
   return *(_objects.at(index.asInt()));
 }
 
+Object &Set::at(const Index &index)
+{
+  return *(_objects.at(index.asInt()));
+}
+
 } // namespace operand
 } // namespace tflite
 } // namespace internal
index dad3ef3..a593d17 100644 (file)
@@ -57,6 +57,70 @@ private:
 } // namespace tflite
 } // namespace internal
 
+#include <algorithm>
+
+namespace internal
+{
+namespace tflite
+{
+namespace operand
+{
+
+struct Data
+{
+  virtual ~Data() = default;
+
+  virtual size_t size(void) const = 0;
+  virtual const uint8_t *base(void) const = 0;
+};
+
+class CachedData final : public Data
+{
+public:
+  CachedData(const uint8_t *base, size_t size) : _base{new uint8_t[size]}, _size{size}
+  {
+    std::copy(base, base + size, _base);
+  }
+
+public:
+  ~CachedData()
+  {
+    delete[] _base;
+  }
+
+public:
+  size_t size(void) const override { return _size; }
+  const uint8_t *base(void) const override { return _base; }
+
+private:
+  uint8_t *_base;
+  size_t _size;
+
+};
+
+class ExternalData final : public Data
+{
+public:
+  ExternalData(const uint8_t *base, size_t size) : _base{base}, _size{size}
+  {
+    // DO NOTHING
+  }
+
+public:
+  size_t size(void) const override { return _size; }
+  const uint8_t *base(void) const override { return _base; }
+
+private:
+  const uint8_t *_base;
+  const size_t _size;
+};
+
+} // namespace operand
+} // namespace tflite
+} // namespace internal
+
+#include <memory>
+
 namespace internal
 {
 namespace tflite
@@ -76,7 +140,17 @@ public:
   const Shape &shape(void) const { return _shape; }
 
 private:
+  void data(std::unique_ptr<Data> &&data) { _data = std::move(data); }
+
+public:
+  template <typename T, typename... Args> void data(Args&&... args)
+  {
+    data(std::unique_ptr<T>(new T{std::forward<Args>(args)...}));
+  }
+
+private:
   const Shape _shape;
+  std::unique_ptr<Data> _data;
 };
 
 } // namespace operand
@@ -99,6 +173,7 @@ public:
 
 public:
   const Object &at(const Index &) const;
+  Object &at(const Index &);
 
 private:
   std::vector<std::unique_ptr<Object>> _objects;
index fc0453c..fdd7080 100644 (file)
@@ -12,6 +12,7 @@ public:
 public:
   size_t size(void) const { return _size; }
   uint8_t *base(void) { return _base; }
+  const uint8_t *base(void) const { return _base; }
 
 private:
   size_t _size;
index 7d276c4..df587e4 100644 (file)
@@ -3,6 +3,7 @@
 #include <cassert>
 
 #include "model.h"
+#include "memory.h"
 
 ResultCode
 ANeuralNetworksModel_create(ANeuralNetworksModel** model)
@@ -48,6 +49,13 @@ ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model,
                                      int32_t index,
                                      const void* buffer, size_t length)
 {
+  const internal::tflite::operand::Index ind{index};
+  auto &obj = model->deref().operands().at(ind);
+
+  using internal::tflite::operand::CachedData;
+
+  obj.data<CachedData>(reinterpret_cast<const uint8_t *>(buffer), length);
+
   return ANEURALNETWORKS_NO_ERROR;
 }
 
@@ -56,6 +64,13 @@ ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model,
                                                int32_t index,
                                                const ANeuralNetworksMemory* memory, size_t offset, size_t length)
 {
+  const internal::tflite::operand::Index ind{index};
+  auto &obj = model->deref().operands().at(ind);
+
+  using internal::tflite::operand::ExternalData;
+
+  obj.data<ExternalData>(reinterpret_cast<const uint8_t *>(memory->base() + offset), length);
+
   return ANEURALNETWORKS_NO_ERROR;
 }