[PureACL] Support other types of operands for setting tensor input in execution (#1642)
author장지섭/동작제어Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Mon, 11 Jun 2018 06:19:57 +0000 (15:19 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 11 Jun 2018 06:19:57 +0000 (15:19 +0900)
This commit supports other types of operands when setting a tensor input in execution.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/pure_arm_compute/src/execution.cc
runtimes/pure_arm_compute/src/internal/TensorSource.h

index 1bf7a8b..9810382 100644 (file)
@@ -104,6 +104,36 @@ static void asVectorSource(ANeuralNetworksExecution *execution, int32_t type, in
   }
 }
 
+static void asTensorSource(ANeuralNetworksExecution *execution, int32_t type, int32_t index,
+                           const nnfw::util::tensor::Shape &shape, const void *buffer,
+                           size_t length)
+{
+  switch (type)
+  {
+    case ANEURALNETWORKS_FLOAT32:
+    case ANEURALNETWORKS_TENSOR_FLOAT32:
+      execution->source<TensorSource<float>>(index, shape, reinterpret_cast<const float *>(buffer),
+                                             length);
+      break;
+    case ANEURALNETWORKS_INT32:
+    case ANEURALNETWORKS_TENSOR_INT32:
+      execution->source<TensorSource<int32_t>>(index, shape,
+                                               reinterpret_cast<const int32_t *>(buffer), length);
+      break;
+    case ANEURALNETWORKS_UINT32:
+      execution->source<TensorSource<uint32_t>>(index, shape,
+                                                reinterpret_cast<const uint32_t *>(buffer), length);
+      break;
+    case ANEURALNETWORKS_TENSOR_QUANT8_ASYMM:
+      execution->source<TensorSource<uint8_t>>(index, shape,
+                                               reinterpret_cast<const uint8_t *>(buffer), length);
+      break;
+    default:
+      throw std::runtime_error("Not supported, yet");
+      break;
+  }
+}
+
 static void asFeatureSource(ANeuralNetworksExecution *execution, int32_t type, int32_t index,
                             const nnfw::util::feature::Shape &shape, const void *buffer,
                             size_t length)
@@ -267,8 +297,7 @@ int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution *execution, int32
   {
     const auto &operand_shape = operands.at(operand_index).shape().asTensor();
 
-    execution->source<TensorSource>(index, operand_shape, reinterpret_cast<const uint8_t *>(buffer),
-                                    length);
+    asTensorSource(execution, input_type, index, operand_shape, buffer, length);
   }
   else if (operands.at(operand_index).shape().rank() == 4)
   {
index dede71f..dc479db 100644 (file)
@@ -8,10 +8,10 @@
 #include "internal/nnapi/tensor/Reader.h"
 #include "internal/arm_compute/tensor/View.h"
 
-class TensorSource final : public Source
+template <typename T> class TensorSource final : public Source
 {
 public:
-  TensorSource(const nnfw::util::tensor::Shape &shape, const uint8_t *base, const size_t size)
+  TensorSource(const nnfw::util::tensor::Shape &shape, const T *base, const size_t size)
       : _shape{shape}, _base{base}, _size{size}
   {
     // DO NOTHING
@@ -20,8 +20,12 @@ public:
 public:
   void push(::arm_compute::ITensor &tensor) const override
   {
-    const ::internal::nnapi::tensor::Reader<float> from{_shape, _base, _size};
-    ::internal::arm_compute::tensor::View<float> into{&tensor};
+    // TODO Change the constructor parameter of Reader and View from uint8_t * to typename T
+    // so that no cast is needed here.
+    // Until then, this reinterpret_cast is unavoidable.
+    const ::internal::nnapi::tensor::Reader<T> from{
+        _shape, reinterpret_cast<const uint8_t *>(_base), _size};
+    ::internal::arm_compute::tensor::View<T> into{&tensor};
 
     ::nnfw::util::tensor::iterate(_shape) << [&](const nnfw::util::tensor::Index &index_nnapi) {
       const auto value = from.at(index_nnapi);
@@ -34,7 +38,7 @@ public:
 
 private:
   const nnfw::util::tensor::Shape _shape;
-  const uint8_t *const _base;
+  const T *const _base;
   const size_t _size;
 };