[new_runtime] Support CPU backend (#1847)
author Hanjoung Lee/Motion Control Lab(SR)/Engineer/Samsung Electronics <hanjoung.lee@samsung.com>
Wed, 4 Jul 2018 09:32:47 +0000 (18:32 +0900)
committer Saehie Park/Motion Control Lab(SR)/Principal Engineer/Samsung Electronics <saehie.park@samsung.com>
Wed, 4 Jul 2018 09:32:47 +0000 (18:32 +0900)
Only the Conv2D operation is supported so far.
Implement #1607 within the current runtime structure; a standalone sketch of the
resulting backend dispatch follows the changed-file list below.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
21 files changed:
runtimes/new_runtime/CMakeLists.txt
runtimes/new_runtime/src/compilation.cc
runtimes/new_runtime/src/execution.cc
runtimes/new_runtime/src/internal/BackendManager.cc
runtimes/new_runtime/src/internal/Model.h
runtimes/new_runtime/src/internal/cpu.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/InitializerGenerator.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/InitializerGenerator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/MemoryAllocator.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/MemoryAllocator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/StageGenerator.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/StageGenerator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/TensorBuilder.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/cpu/TensorBuilder.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/kernels/cpufallback/CPUConvolutionLayer.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/kernels/cpufallback/CPUConvolutionLayer.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/nnapi/kernel/View.h [new file with mode: 0644]
runtimes/new_runtime/src/model.cc
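
For orientation, the change to compilation.cc below routes operations by keying a
generator map on each node's typeid, which is how Conv2D is now sent to the new
"cpu" backend while the other operations stay on arm_compute. The standalone C++
sketch below illustrates only that dispatch pattern; the node and backend names are
hypothetical stand-ins, not the runtime's actual classes.

// Minimal sketch of typeid-keyed backend dispatch (illustrative types only).
#include <iostream>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

struct Conv2DNode {};    // stand-in for ::internal::tflite::op::Conv2D::implicit::Node
struct MaxPool2DNode {}; // stand-in for ::internal::tflite::op::MaxPool2D::implicit::Node

int main()
{
  std::unordered_map<std::type_index, std::string> gen_map;
  gen_map[typeid(Conv2DNode)] = "cpu";            // this commit: Conv2D -> CPU backend
  gen_map[typeid(MaxPool2DNode)] = "arm_compute"; // other operations stay on arm_compute

  std::cout << gen_map.at(typeid(Conv2DNode)) << std::endl; // prints "cpu"
  return 0;
}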

diff --git a/runtimes/new_runtime/CMakeLists.txt b/runtimes/new_runtime/CMakeLists.txt
index 5dd2a48..a8ab694 100644 (file)
@@ -1,12 +1,13 @@
 file(GLOB_RECURSE SOURCES "src/*.cc")
 
 add_library(${LIB_NEW_RUNTIME} SHARED ${SOURCES})
-target_include_directories(${LIB_NEW_RUNTIME} PUBLIC ${NNFW_INCLUDE_DIR}
-                           ${CMAKE_SOURCE_DIR}/externals/acl
-                           ${CMAKE_SOURCE_DIR}/externals/acl/include)
-target_include_directories(${LIB_NEW_RUNTIME} PUBLIC src)
+target_include_directories(${LIB_NEW_RUNTIME} PUBLIC ${NNFW_INCLUDE_DIR})
+target_include_directories(${LIB_NEW_RUNTIME} PUBLIC src
+                           ${CMAKE_SOURCE_DIR}/externals/tensorflow)
 target_link_libraries(${LIB_NEW_RUNTIME} arm_compute)
+target_link_libraries(${LIB_NEW_RUNTIME} tensorflow-lite)
 target_link_libraries(${LIB_NEW_RUNTIME} nnfw_util)
+target_link_libraries(${LIB_NEW_RUNTIME} nnfw_support_nnapi)
 set_target_properties(${LIB_NEW_RUNTIME} PROPERTIES OUTPUT_NAME neuralnetworks)
 
 install(TARGETS ${LIB_NEW_RUNTIME} DESTINATION lib/new_runtime)
diff --git a/runtimes/new_runtime/src/compilation.cc b/runtimes/new_runtime/src/compilation.cc
index 9e13729..af6ff18 100644 (file)
@@ -77,9 +77,10 @@ public:
   BackendResolver(::internal::BackendManager &backend_manager)
   {
     auto acl_gen = backend_manager.get("arm_compute");
+    auto cpu_gen = backend_manager.get("cpu");
 
     // TODO Set generator map according to environment variable
-    _gen_map[typeid(::internal::tflite::op::Conv2D::implicit::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::Conv2D::implicit::Node)] = cpu_gen;
     _gen_map[typeid(::internal::tflite::op::MaxPool2D::implicit::Node)] = acl_gen;
     _gen_map[typeid(::internal::tflite::op::AvgPool2D::implicit::Node)] = acl_gen;
     _gen_map[typeid(::internal::tflite::op::Concat::Node)] = acl_gen;
diff --git a/runtimes/new_runtime/src/execution.cc b/runtimes/new_runtime/src/execution.cc
index e3b9a91..a5718bb 100644 (file)
@@ -6,6 +6,9 @@
 #include "internal/nnapi/feature/Reader.h"
 #include "internal/nnapi/feature/View.h"
 
+#include "internal/cpu.h"
+#include <arm_compute/runtime/CL/CLTensor.h>
+
 #include "internal/arm_compute/feature/View.h"
 
 #include "util/feature/IndexIterator.h"
@@ -58,13 +61,27 @@ public:
 public:
   void push(::arm_compute::ITensor &tensor) const override
   {
-    const ::internal::nnapi::feature::Reader<float> from{_shape, _base, _size};
-    ::internal::arm_compute::feature::View<float> into{&tensor};
+    // TODO: This is a workaround; it needs to be refactored.
+    if (typeid(tensor) == typeid(::internal::cpu::Tensor))
+    {
+      const ::internal::nnapi::feature::Reader<float> from{_shape, _base, _size};
+      ::internal::nnapi::feature::View<float> into{_shape, tensor.buffer(), _size};
+
+      ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+        const auto value = from.at(ch, row, col);
+        into.at(ch, row, col) = value;
+      };
+    }
+    else if (typeid(tensor) == typeid(::arm_compute::CLTensor))
+    {
+      const ::internal::nnapi::feature::Reader<float> from{_shape, _base, _size};
+      ::internal::arm_compute::feature::View<float> into{&tensor};
 
-    ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
-      const auto value = from.at(ch, row, col);
-      into.at(ch, row, col) = value;
-    };
+      ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+        const auto value = from.at(ch, row, col);
+        into.at(ch, row, col) = value;
+      };
+    }
   }
 
 private:
@@ -118,13 +135,27 @@ public:
 public:
   void pull(::arm_compute::ITensor &tensor) const override
   {
-    const ::internal::arm_compute::feature::View<float> from{&tensor};
-    ::internal::nnapi::feature::View<float> into{_shape, _base, _size};
+    // TODO: This is a workaround; it needs to be refactored.
+    if (typeid(tensor) == typeid(::internal::cpu::Tensor))
+    {
+      const ::internal::nnapi::feature::Reader<float> from{_shape, tensor.buffer(), _size};
+      ::internal::nnapi::feature::View<float> into{_shape, _base, _size};
 
-    ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
-      const auto value = from.at(ch, row, col);
-      into.at(ch, row, col) = value;
-    };
+      ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+        const auto value = from.at(ch, row, col);
+        into.at(ch, row, col) = value;
+      };
+    }
+    else if (typeid(tensor) == typeid(::arm_compute::CLTensor))
+    {
+      const ::internal::arm_compute::feature::View<float> from{&tensor};
+      ::internal::nnapi::feature::View<float> into{_shape, _base, _size};
+
+      ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+        const auto value = from.at(ch, row, col);
+        into.at(ch, row, col) = value;
+      };
+    }
   }
 
 private:
diff --git a/runtimes/new_runtime/src/internal/BackendManager.cc b/runtimes/new_runtime/src/internal/BackendManager.cc
index ca11ad1..28d5c02 100644 (file)
@@ -3,6 +3,9 @@
 #include "internal/arm_compute/TensorBuilder.h"
 #include "internal/arm_compute/InitializerGenerator.h"
 #include "internal/arm_compute/StageGenerator.h"
+#include "internal/cpu/TensorBuilder.h"
+#include "internal/cpu/InitializerGenerator.h"
+#include "internal/cpu/StageGenerator.h"
 
 namespace internal
 {
@@ -20,7 +23,14 @@ BackendManager::BackendManager(::internal::arm_compute::Plan& plan) : _plan(plan
     _gen_map["arm_compute"] = {acl_initializer_gen, acl_stage_gen};
   }
 
-  // TODO Add CPU backend
+  // Add CPU backend
+  {
+    auto cpu_tensor_builder = std::make_shared<::internal::cpu::TensorBuilder>(_plan);
+    auto cpu_initializer_gen = std::make_shared<::internal::cpu::InitializerGenerator>(operands);
+    auto cpu_stage_gen = std::make_shared<::internal::cpu::StageGenerator>(operands, cpu_tensor_builder);
+
+    _gen_map["cpu"] = {cpu_initializer_gen, cpu_stage_gen};
+  }
 }
 
 Backend BackendManager::get(const std::string &key)
diff --git a/runtimes/new_runtime/src/internal/Model.h b/runtimes/new_runtime/src/internal/Model.h
index 8357bb8..629290a 100644 (file)
@@ -51,6 +51,14 @@ public:
 public:
   int32_t dim(uint32_t n) const { return _dims.at(n); }
   int32_t &dim(uint32_t n) { return _dims.at(n); }
+  const std::vector<int32_t> &dims() const { return _dims; }
+  int32_t type() const { return _type; }
+  float scale() const { return _scale; }
+  void set(int32_t type, float scale)
+  {
+    _type = type;
+    _scale = scale;
+  }
 
 public:
   int32_t asVector(void) const;
@@ -59,6 +67,8 @@ public:
 
 private:
   std::vector<int32_t> _dims;
+  int32_t _type;
+  float _scale;
 };
 
 } // namespace operand
diff --git a/runtimes/new_runtime/src/internal/cpu.cc b/runtimes/new_runtime/src/internal/cpu.cc
new file mode 100644 (file)
index 0000000..48ee872
--- /dev/null
@@ -0,0 +1,17 @@
+#include "internal/cpu.h"
+
+namespace internal
+{
+namespace cpu
+{
+namespace operand
+{
+
+void Object::access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const
+{
+  fn(*_tensor);
+}
+
+} // namespace operand
+} // namespace cpu
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/cpu.h b/runtimes/new_runtime/src/internal/cpu.h
new file mode 100644 (file)
index 0000000..3e5b6d5
--- /dev/null
@@ -0,0 +1,125 @@
+#ifndef __INTERNAL_CPU_H__
+#define __INTERNAL_CPU_H__
+
+#include <arm_compute/core/ITensor.h>
+#include <arm_compute/core/TensorInfo.h>
+#include <arm_compute/core/CL/OpenCL.h>
+
+#include <unistd.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <iostream>
+
+#include "internal/IObject.h"
+
+namespace internal
+{
+namespace cpu
+{
+namespace operand
+{
+
+class Object : public ::internal::IObject
+{
+public:
+  Object() = default;
+
+public:
+  Object(const std::shared_ptr<::arm_compute::ITensor> &tensor) : _tensor{tensor}
+  {
+    // DO NOTHING
+  }
+
+public:
+  ::arm_compute::ITensor *ptr(void) const override { return _tensor.get(); }
+
+private:
+  std::shared_ptr<::arm_compute::ITensor> _tensor;
+
+public:
+  void access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const override;
+};
+
+} // namespace operand
+} // namespace cpu
+} // namespace internal
+
+
+namespace internal
+{
+namespace cpu
+{
+
+#define PATH_MAX 256
+
+static int shmem_num = 0;
+static int shmem_create_region(size_t size)
+{
+    char temp[PATH_MAX];
+    snprintf(temp, sizeof(temp), "/tmp/nn-shmem-%d-%d-XXXXXXXXX", getpid(), shmem_num++);
+    int fd = mkstemp(temp);
+    if (fd == -1) return -1;
+
+    unlink(temp);
+
+    if (TEMP_FAILURE_RETRY(ftruncate(fd, size)) == -1) {
+      close(fd);
+      return -1;
+    }
+
+    return fd;
+}
+
+class Tensor : public ::arm_compute::ITensor
+{
+public:
+  Tensor() = default;
+
+  Tensor(::arm_compute::TensorInfo info) : _info(info)
+  {
+    uint32_t size = _info.total_size();
+    // TODO Do not use shared memory
+    int fd = shmem_create_region(size);
+    _buffer = reinterpret_cast<uint8_t *>(mmap(0, size, PROT_WRITE | PROT_READ, MAP_PRIVATE, fd, 0));
+  }
+
+  Tensor(uint8_t *buffer) : _buffer(buffer)
+  {
+    // DO NOTHING
+  }
+
+public:
+
+  void setBuffer(uint8_t *buffer)
+  {
+    _buffer = buffer;
+  }
+
+public:
+  ::arm_compute::TensorInfo *info() const override
+  {
+    return const_cast<::arm_compute::TensorInfo*>(&_info);
+  }
+
+  ::arm_compute::TensorInfo *info() override
+  {
+    return &_info;
+  }
+
+  uint8_t *buffer() const override
+  {
+    return _buffer;
+  }
+
+private:
+  ::arm_compute::TensorInfo _info;
+  uint8_t *_buffer = nullptr;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_H__
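
As a side note on the buffer handling above: internal::cpu::Tensor backs its storage
with an unlinked temporary file that is sized with ftruncate() and mapped with mmap().
The standalone sketch below reproduces only that POSIX sequence, with error checks
added; the path, size, and names are illustrative, not taken from the runtime.

// Standalone illustration of the file-backed allocation pattern used above.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

int main()
{
  const size_t size = 4096;
  char path[] = "/tmp/nn-shmem-demo-XXXXXX";

  int fd = mkstemp(path);            // create a unique backing file
  if (fd == -1) return 1;
  unlink(path);                      // keep only the open descriptor

  if (ftruncate(fd, size) == -1)     // grow the file to the requested size
  {
    close(fd);
    return 1;
  }

  void *buf = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, 0);
  close(fd);                         // the mapping stays valid after closing the fd
  if (buf == MAP_FAILED) return 1;

  memset(buf, 0, size);              // the buffer is now usable as tensor storage
  munmap(buf, size);
  return 0;
}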
diff --git a/runtimes/new_runtime/src/internal/cpu/InitializerGenerator.cc b/runtimes/new_runtime/src/internal/cpu/InitializerGenerator.cc
new file mode 100644 (file)
index 0000000..e47dcf7
--- /dev/null
@@ -0,0 +1,72 @@
+#include "internal/cpu/InitializerGenerator.h"
+
+#include "internal/nnapi/kernel/Reader.h"
+#include "internal/nnapi/kernel/View.h"
+#include "util/kernel/IndexIterator.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+InitializerGenerator::InitializerGenerator(const ::internal::tflite::operand::Set &ctx) : _ctx(ctx)
+{
+  // DO NOTHING
+}
+
+Initializer InitializerGenerator::generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
+
+  const auto ker_shape = _ctx.at(ker_index).shape().asKernel();
+  auto ker_base = _ctx.at(ker_index).data().base();
+  auto ker_size = _ctx.at(ker_index).data().size();
+
+  return [ker_shape, ker_base, ker_size](::arm_compute::ITensor &tensor) {
+    const ::internal::nnapi::kernel::Reader<float> from{ker_shape, ker_base, ker_size};
+    ::internal::nnapi::kernel::View<float> into{&tensor};
+
+    ::nnfw::util::kernel::iterate(ker_shape)
+        << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
+             const auto value = from.at(nth, ch, row, col);
+             into.at(nth, row, col, ch) = value;
+           };
+  };
+}
+
+Initializer InitializerGenerator::generateWeight(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+Initializer InitializerGenerator::generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  // TODO Refactor so we can reuse the common code
+
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+  auto bias_base = _ctx.at(bias_index).data().base();
+  const auto bias_size = _ctx.at(bias_index).shape().asVector();
+
+  return [bias_base, bias_size](::arm_compute::ITensor &tensor) {
+    for (uint32_t n = 0; n < bias_size; ++n)
+    {
+      const ::arm_compute::Coordinates coordinate{n};
+
+      float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+      const float *from = reinterpret_cast<const float *>(bias_base) + n;
+      const auto value = *from;
+
+      *into = value;
+    }
+  };
+}
+
+Initializer InitializerGenerator::generateBias(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+} // namespace cpu
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/cpu/InitializerGenerator.h b/runtimes/new_runtime/src/internal/cpu/InitializerGenerator.h
new file mode 100644 (file)
index 0000000..be18ae3
--- /dev/null
@@ -0,0 +1,31 @@
+#ifndef __INTERNAL_CPU_INITIALIZER_GENERATOR_H__
+#define __INTERNAL_CPU_INITIALIZER_GENERATOR_H__
+
+#include "internal/IInitializerGenerator.h"
+
+#include "internal/Model.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+class InitializerGenerator : public ::internal::IInitializerGenerator
+{
+public:
+  InitializerGenerator(const ::internal::tflite::operand::Set &ctx);
+
+  Initializer generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  Initializer generateWeight(const ::internal::tflite::op::FullyConnected::Node &node) override;
+
+  Initializer generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  Initializer generateBias(const ::internal::tflite::op::FullyConnected::Node &node) override;
+
+private:
+  const ::internal::tflite::operand::Set &_ctx;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_INITIALIZER_GENERATOR_H__
diff --git a/runtimes/new_runtime/src/internal/cpu/MemoryAllocator.cc b/runtimes/new_runtime/src/internal/cpu/MemoryAllocator.cc
new file mode 100644 (file)
index 0000000..bdf58d7
--- /dev/null
@@ -0,0 +1,2 @@
+//#include "internal/cpu/MemoryAllocator.h"
+
diff --git a/runtimes/new_runtime/src/internal/cpu/MemoryAllocator.h b/runtimes/new_runtime/src/internal/cpu/MemoryAllocator.h
new file mode 100644 (file)
index 0000000..dbefc80
--- /dev/null
@@ -0,0 +1,103 @@
+#ifndef __INTERNAL_CPU_MEMORY_ALLOCATOR_H__
+#define __INTERNAL_CPU_MEMORY_ALLOCATOR_H__
+
+#include "arm_compute/runtime/ITensorAllocator.h"
+#include "arm_compute/runtime/Memory.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+namespace arm_compute
+{
+class Coordinates;
+class TensorInfo;
+class Tensor;
+};
+
+/** Basic implementation of a CPU memory tensor allocator. */
+class TensorAllocator : public ITensorAllocator
+{
+public:
+    /** Default constructor. */
+    TensorAllocator(Tensor *owner = nullptr);
+    /** Default destructor */
+    ~TensorAllocator();
+
+    /** Make ITensorAllocator's init methods available */
+    using ITensorAllocator::init;
+
+    /** Shares the same backing memory with another tensor allocator, while the tensor info might be different.
+     *  In other words this can be used to create a sub-tensor from another tensor while sharing the same memory.
+     *
+     * @note TensorAllocators have to be of the same specialized type.
+     *
+     * @param[in] allocator The allocator that owns the backing memory to be shared. Ownership becomes shared afterwards.
+     * @param[in] coords    The starting coordinates of the new tensor inside the parent tensor.
+     * @param[in] sub_info  The new tensor information (e.g. shape etc)
+     */
+    void init(const TensorAllocator &allocator, const Coordinates &coords, TensorInfo sub_info);
+
+    /** Returns the pointer to the allocated data. */
+    uint8_t *data() const;
+
+    /** Allocate CPU memory of the size specified by TensorInfo.
+     *
+     * @note The tensor must not already be allocated when calling this function.
+     *
+     */
+    void allocate() override;
+
+    /** Free allocated CPU memory.
+     *
+     * @note The tensor must have been allocated when calling this function.
+     *
+     */
+    void free() override;
+    /** Import an existing memory as a tensor's backing memory
+     *
+     * @warning If the tensor is flagged to be managed by a memory manager,
+     *          this call will lead to an error.
+     * @warning Ownership of memory depends on the way the @ref Memory object was constructed
+     * @note    Calling free on a tensor with imported memory will just clear
+     *          the internal pointer value.
+     *
+     * @param[in] memory Memory to import
+     *
+     * @return error status
+     */
+    arm_compute::Status import_memory(Memory memory);
+    /** Associates the tensor with a memory group
+     *
+     * @param[in] associated_memory_group Memory group to associate the tensor with
+     */
+    void set_associated_memory_group(MemoryGroup *associated_memory_group);
+
+protected:
+    /** No-op for CPU memory
+     *
+     * @return A pointer to the beginning of the tensor's allocation.
+     */
+    uint8_t *lock() override;
+
+    /** No-op for CPU memory. */
+    void unlock() override;
+
+private:
+    MemoryGroup *_associated_memory_group; /**< Registered memory manager */
+    Memory       _memory;                  /**< CPU memory */
+    Tensor      *_owner;                   /**< Owner of the allocator */
+};
+
+namespace internal
+{
+namespace cpu
+{
+
+class MemoryAllocator {
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_MEMORY_ALLOCATOR_H__
diff --git a/runtimes/new_runtime/src/internal/cpu/StageGenerator.cc b/runtimes/new_runtime/src/internal/cpu/StageGenerator.cc
new file mode 100644 (file)
index 0000000..e9b4f7e
--- /dev/null
@@ -0,0 +1,135 @@
+#include "internal/cpu/StageGenerator.h"
+
+#include <stdexcept>
+
+#include "internal/Padding.h"
+#include "internal/kernels/cpufallback/CPUConvolutionLayer.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+StageGenerator::StageGenerator(const ::internal::tflite::operand::Set &operand_ctx,
+                               const std::shared_ptr<::internal::cpu::TensorBuilder> &tensor_builder)
+  : _ctx(operand_ctx), _tensor_builder(tensor_builder)
+{
+  // DO NOTHING
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+  const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+  const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
+  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
+
+  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
+  const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
+
+  const PaddingCode padding_type =
+      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
+
+  assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
+         (ANEURALNETWORKS_PADDING_VALID == padding_type));
+
+  Stride stride;
+
+  stride.vertical = _ctx.at(vstride_index).asScalar<int32_t>();
+  stride.horizontal = _ctx.at(hstride_index).asScalar<int32_t>();
+
+  // Construct operation parameters
+  struct Param
+  {
+    int ofm_index;
+    int ifm_index;
+    int ker_index;
+    int bias_index;
+
+    ::internal::tflite::operand::Shape ofm_shape{1};
+    ::internal::tflite::operand::Shape ifm_shape{1};
+    ::internal::tflite::operand::Shape ker_shape{1};
+    ::internal::tflite::operand::Shape bias_shape{1};
+
+    Padding padding;
+    Stride stride;
+
+    FuseCode activation;
+  };
+
+  Param param;
+
+  param.ofm_index = ofm_index.asInt();
+  param.ifm_index = ifm_index.asInt();
+  param.ker_index = ker_index.asInt();
+  param.bias_index = bias_index.asInt();
+
+  param.ofm_shape = _ctx.at(ofm_index).shape();
+  param.ifm_shape = _ctx.at(ifm_index).shape();
+  param.ker_shape = _ctx.at(ker_index).shape();
+  param.bias_shape = _ctx.at(bias_index).shape();
+
+  param.stride = stride;
+  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
+                      ? same_padding(param.ifm_shape.asFeature(), param.ofm_shape.asFeature(),
+                                     stride, param.ker_shape.asKernel().W,
+                                     param.ker_shape.asKernel().H)
+                      : valid_padding();
+
+  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto ofm_alloc = tensors->at(::internal::tflite::operand::Index{param.ofm_index});
+    auto ifm_alloc = tensors->at(::internal::tflite::operand::Index{param.ifm_index});
+    auto ker_alloc = tensors->at(::internal::tflite::operand::Index{param.ker_index});
+    auto bias_alloc = tensors->at(::internal::tflite::operand::Index{param.bias_index});
+
+    std::unique_ptr<::internal::kernels::cpu::CPUConvolutionLayer> fn{
+        new ::internal::kernels::cpu::CPUConvolutionLayer};
+
+    fn->configure(ifm_alloc->buffer(), param.ifm_shape, ker_alloc->buffer(), param.ker_shape, bias_alloc->buffer(),
+                  param.bias_shape, param.padding.left, param.padding.right, param.padding.top, param.padding.bottom,
+                  param.stride.horizontal, param.stride.vertical, param.activation, ofm_alloc->buffer(),
+                  param.ofm_shape);
+
+    builder.append(std::move(fn));
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Concat::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Reshape::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Softmax::Node &node)
+{
+  throw std::runtime_error("NYI");
+}
+
+} // namespace cpu
+} // namespace internal
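
The Conv2D stage above selects between same_padding() and valid_padding() from
internal/Padding.h depending on the NNAPI padding code. As a hedged illustration
(this is the standard NNAPI/TensorFlow SAME-padding rule, not necessarily the exact
code in internal/Padding.h), the sketch below computes the padding for one spatial
dimension.

// Standard SAME-padding computation for one spatial dimension (illustrative).
#include <algorithm>
#include <cassert>
#include <cstdint>

struct Pad { uint32_t head; uint32_t tail; };

static Pad same_padding_1d(uint32_t in, uint32_t kernel, uint32_t stride)
{
  const uint32_t out = (in + stride - 1) / stride;                // ceil(in / stride)
  const int32_t needed = static_cast<int32_t>((out - 1) * stride + kernel) -
                         static_cast<int32_t>(in);
  const uint32_t total = static_cast<uint32_t>(std::max<int32_t>(needed, 0));
  return Pad{total / 2, total - total / 2};                       // extra pixel goes to the tail
}

int main()
{
  const Pad p = same_padding_1d(224, 3, 2);  // e.g. 224-wide input, 3x3 kernel, stride 2
  assert(p.head + p.tail == 1);              // output width 112 needs 1 pixel of padding
  return 0;
}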
diff --git a/runtimes/new_runtime/src/internal/cpu/StageGenerator.h b/runtimes/new_runtime/src/internal/cpu/StageGenerator.h
new file mode 100644 (file)
index 0000000..bed0498
--- /dev/null
@@ -0,0 +1,39 @@
+#ifndef __INTERNAL_CPU_STAGE_GENERATOR_H__
+#define __INTERNAL_CPU_STAGE_GENERATOR_H__
+
+#include "internal/IStageGenerator.h"
+
+#include "internal/Model.h"
+#include "internal/cpu.h"
+#include "internal/cpu/TensorBuilder.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+class StageGenerator : public ::internal::IStageGenerator
+{
+public:
+  StageGenerator(const ::internal::tflite::operand::Set &ctx,
+                 const std::shared_ptr<::internal::cpu::TensorBuilder> &tensor_builder);
+
+  virtual std::shared_ptr<ITensorBuilder> tensor_builder() override { return _tensor_builder; }
+
+  virtual Stage generate(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::Concat::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::FullyConnected::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::Reshape::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::Softmax::Node &node) override;
+
+private:
+  const ::internal::tflite::operand::Set &_ctx;
+  std::shared_ptr<::internal::cpu::TensorBuilder> _tensor_builder;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_STAGE_GENERATOR_H__
diff --git a/runtimes/new_runtime/src/internal/cpu/TensorBuilder.cc b/runtimes/new_runtime/src/internal/cpu/TensorBuilder.cc
new file mode 100644 (file)
index 0000000..5709eac
--- /dev/null
@@ -0,0 +1,53 @@
+#include "internal/cpu/TensorBuilder.h"
+
+#include <cassert>
+
+#include "internal/arm_compute.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+TensorBuilder::TensorBuilder(::internal::arm_compute::Plan &plan) : _plan(plan)
+{
+  // DO NOTHING
+}
+
+void TensorBuilder::mark(const ::internal::tflite::operand::Index& ind)
+{
+  assert(_tensors.size() == 0);
+
+  _inds.insert(ind.asInt());
+}
+
+void TensorBuilder::prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx)
+{
+  assert(_tensors.size() == 0);
+
+  for (auto ind_int : _inds)
+  {
+    ::internal::tflite::operand::Index ind{ind_int};
+    auto tensor = std::make_shared<::internal::cpu::Tensor>(tensor_info_ctx.at(ind.asInt()));
+    // TODO Fix allocation here. When Tensor object is created the memory for tensor is also
+    //      allocated, and this must be fixed.
+    _plan.operands().set(ind, std::make_shared<::internal::cpu::operand::Object>(tensor));
+    _tensors[ind.asInt()] = tensor;
+  }
+}
+
+void TensorBuilder::allocate(void)
+{
+  assert(_inds.size() == _tensors.size());
+
+  // NOTE For now nothing to do. Allocation is done in prepare stage, which is wrong
+  //      See also: comment in `prepare()`
+}
+
+std::shared_ptr<::internal::cpu::Tensor> TensorBuilder::at(const ::internal::tflite::operand::Index &ind)
+{
+  return _tensors.at(ind.asInt());
+}
+
+} // namespace cpu
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/cpu/TensorBuilder.h b/runtimes/new_runtime/src/internal/cpu/TensorBuilder.h
new file mode 100644 (file)
index 0000000..c95bdee
--- /dev/null
@@ -0,0 +1,38 @@
+#ifndef __INTERNAL_CPU_TENSOR_BUILDER_H__
+#define __INTERNAL_CPU_TENSOR_BUILDER_H__
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "internal/ITensorBuilder.h"
+#include "internal/cpu.h"
+#include "internal/arm_compute.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+class Plan;
+
+class TensorBuilder : public ::internal::ITensorBuilder
+{
+public:
+  TensorBuilder(::internal::arm_compute::Plan &plan);
+
+  virtual void mark(const ::internal::tflite::operand::Index& ind) override;
+  virtual void prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx) override;
+  virtual void allocate(void) override;
+
+  std::shared_ptr<::internal::cpu::Tensor> at(const ::internal::tflite::operand::Index &ind);
+
+private:
+  ::internal::arm_compute::Plan &_plan;
+  std::unordered_set<int> _inds;
+  std::unordered_map<int, std::shared_ptr<::internal::cpu::Tensor>> _tensors;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_TENSOR_BUILDER_H__
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/CPUConvolutionLayer.cc b/runtimes/new_runtime/src/internal/kernels/cpufallback/CPUConvolutionLayer.cc
new file mode 100644 (file)
index 0000000..4e38d61
--- /dev/null
@@ -0,0 +1,173 @@
+#include "CPUConvolutionLayer.h"
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+// If possible we will use this static buffer for the tensor.
+static constexpr int kStaticBufferSize = 1605632;
+static char static_scratch_buffer[kStaticBufferSize];
+
+#define ANDROID_NN_CONV_PARAMETERS(Type)                                        \
+    uint32_t height       = getSizeOfDimension(_inputShape, 1);                  \
+    uint32_t width        = getSizeOfDimension(_inputShape, 2);                  \
+    uint32_t kernelHeight = getSizeOfDimension(_kernelShape, 1);                 \
+    uint32_t kernelWidth  = getSizeOfDimension(_kernelShape, 2);                 \
+    uint32_t outHeight    = getSizeOfDimension(_outputShape, 1);                 \
+    uint32_t outWidth     = getSizeOfDimension(_outputShape, 2);                 \
+    uint32_t inDepth      = getSizeOfDimension(_inputShape, 3);                  \
+                                                                                \
+    uint32_t paddingHeight = (uint32_t)_paddingTop;                             \
+    uint32_t paddingWidth = (uint32_t)_paddingLeft;                             \
+                                                                                \
+    ::tflite::Dims<4> im2colDim;                                                  \
+    im2colDim.sizes[3] = (int)getSizeOfDimension(_outputShape, 0);               \
+    im2colDim.sizes[2] = (int)getSizeOfDimension(_outputShape, 1);               \
+    im2colDim.sizes[1] = (int)getSizeOfDimension(_outputShape, 2);               \
+    im2colDim.sizes[0] = (int)inDepth * kernelHeight * kernelWidth;             \
+                                                                                \
+    im2colDim.strides[0] = 1;                                                   \
+    for (int i=1; i<4; i++) {                                                   \
+        im2colDim.strides[i] = im2colDim.strides[i-1] * im2colDim.sizes[i-1];   \
+    }                                                                           \
+    Type* im2colData = nullptr;                                                 \
+    uint64_t im2colByteSize = sizeof(Type);                                     \
+    std::unique_ptr<Type[]> im2colGuard;                                        \
+    for (int i=0; i<4; i++) {                                                   \
+        im2colByteSize *= im2colDim.sizes[i];                                   \
+    }                                                                           \
+    /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */   \
+    if (im2colByteSize >= 0x7fffffff)  {                                        \
+        std::cout << "Conv size is too large, not enough memory" << std::endl;  \
+        return false;                                                           \
+    }                                                                           \
+    if (im2colByteSize <= kStaticBufferSize) {                                  \
+        im2colData = reinterpret_cast<Type *>(static_scratch_buffer);           \
+    } else {                                                                    \
+        im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)];    \
+        if (im2colData == nullptr) {                                            \
+            std::cout << "Conv size is too large, not enough memory" << std::endl; \
+            return false;                                                       \
+        }                                                                       \
+        im2colGuard.reset(im2colData);                                          \
+    }
+
+bool CPUConvolutionLayer::convFloat32()
+{
+  ANDROID_NN_CONV_PARAMETERS(float)
+
+  float output_activation_min, output_activation_max;
+  CalculateActivationRangeFloat(_activation, &output_activation_min,
+                                &output_activation_max);
+  int32_t dilationWidthFactor = 1, dilationHeightFactor = 1;
+  ::tflite::optimized_ops::Conv(
+          reinterpret_cast<const float *>(_inputData), convertShapeToDims(_inputShape),
+          reinterpret_cast<const float *>(_kernelData), convertShapeToDims(_kernelShape),
+          reinterpret_cast<const float *>(_biasData), convertShapeToDims(_biasShape),
+          _strideWidth, _strideHeight,
+          dilationWidthFactor, dilationHeightFactor,
+          paddingWidth, paddingHeight,
+          output_activation_min, output_activation_max,
+          reinterpret_cast<float *>(_outputData), convertShapeToDims(_outputShape),
+          im2colData, im2colDim);
+  return true;
+}
+
+bool CPUConvolutionLayer::convQuant8()
+{
+/*
+  ANDROID_NN_CONV_PARAMETERS(uint8_t)
+
+  int32_t inputOffset = -inputShape.offset;
+  int32_t filterOffset = -filterShape.offset;
+  int32_t outputOffset = outputShape.offset;
+
+  float real_multiplier = 0.0;
+  int32_t output_multiplier = 0;
+  int32_t output_shift = 0;
+  int32_t output_activation_min = 0;
+  int32_t output_activation_max = 0;
+
+  if (!GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape, outputShape,
+                                        &real_multiplier) ||
+      !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift))
+  {
+    // Following code inserted to resolve Coverity (118950 Resource leak)
+    if (im2colByteSize > kStaticBufferSize)
+    {
+      delete[] im2colData;
+    }
+    return false;
+  }
+  CalculateActivationRangeUint8(activation, outputShape, &output_activation_min,
+                                &output_activation_max);
+
+  static gemmlowp::GemmContext gemm_context;
+  // Allow gemmlowp to automatically decide how many threads to use.
+  gemm_context.set_max_num_threads(0);
+
+#define ANDROID_NN_CONV(activation)                                                              \
+  optimized_ops::Conv<FusedActivationFunctionType::activation>(                                  \
+      inputData, convertShapeToDims(inputShape), inputOffset, filterData,                        \
+      convertShapeToDims(filterShape), filterOffset, biasData, convertShapeToDims(biasShape),    \
+      stride_width, stride_height, paddingWidth, paddingHeight, outputOffset, output_multiplier, \
+      output_shift, output_activation_min, output_activation_max, outputData,                    \
+      convertShapeToDims(outputShape), im2colData, im2colDim, &gemm_context)
+
+  ANDROID_NN_MACRO_DISPATCH_WITH_DELETE(ANDROID_NN_CONV)
+#undef ANDROID_NN_CONV
+
+  if (im2colByteSize > kStaticBufferSize)
+  {
+    delete[] im2colData;
+  }*/
+  return true;
+}
+
+void CPUConvolutionLayer::configure(uint8_t *inputData, const internal::tflite::operand::Shape inputShape, uint8_t *kernelData,
+                const internal::tflite::operand::Shape kernelShape, uint8_t *biasData, const internal::tflite::operand::Shape biasShape,
+                const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, const uint32_t paddingBottom,
+                const uint32_t strideWidth, const uint32_t strideHeight, const FuseCode activation, uint8_t *outputData, const internal::tflite::operand::Shape outputShape)
+{
+  _inputData = inputData;
+  _inputShape = convertShape(inputShape);
+  _kernelData = kernelData;
+  _kernelShape = convertShape(kernelShape);
+  _biasData = biasData;
+  _biasShape = convertShape(biasShape);
+  _paddingLeft = paddingLeft;
+  _paddingRight = paddingRight;
+  _paddingTop = paddingTop;
+  _paddingBottom = paddingBottom;
+  _strideWidth = strideWidth;
+  _strideHeight = strideHeight;
+  _activation = activation;
+  _outputData = outputData;
+  _outputShape = convertShape(outputShape);
+}
+
+
+void CPUConvolutionLayer::run()
+{
+  convFloat32();
+  /*
+  if (input.type == OperandType::TENSOR_FLOAT32)
+  {
+  }
+  else if (input.type == OperandType::TENSOR_QUANT8_ASYMM)
+  {
+  }
+  */
+}
+
+#undef ANDROID_NN_CONV_PARAMETERS
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/CPUConvolutionLayer.h b/runtimes/new_runtime/src/internal/kernels/cpufallback/CPUConvolutionLayer.h
new file mode 100644 (file)
index 0000000..39c1fdf
--- /dev/null
@@ -0,0 +1,65 @@
+#ifndef __INTERNAL_KERNELS_CPU_CPUCONVOLUTIONLAYER_H__
+#define __INTERNAL_KERNELS_CPU_CPUCONVOLUTIONLAYER_H__
+
+#include <NeuralNetworks.h>
+
+#include <arm_compute/runtime/IFunction.h>
+
+#include "internal/Model.h"
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+using namespace internal::kernels::cpu;
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+class CPUConvolutionLayer : public ::arm_compute::IFunction
+{
+public:
+  CPUConvolutionLayer()
+  {
+  }
+
+public:
+  bool convFloat32();
+
+  bool convQuant8();
+
+  void configure(uint8_t *inputData, const internal::tflite::operand::Shape inputShape, uint8_t *kernelData,
+                const internal::tflite::operand::Shape kernelShape, uint8_t *biasData, const internal::tflite::operand::Shape biasShape,
+                const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, const uint32_t paddingBottom,
+                const uint32_t strideW, const uint32_t strideH, const FuseCode activation, uint8_t *outputData, const internal::tflite::operand::Shape outputShape);
+
+  void run();
+
+private:
+  uint8_t *_inputData;
+  uint8_t *_kernelData;
+  uint8_t *_outputData;
+  uint8_t *_biasData;
+
+  Shape _inputShape;
+  Shape _kernelShape;
+  Shape _outputShape;
+  Shape _biasShape;
+
+  uint32_t _paddingLeft;
+  uint32_t _paddingTop;
+  uint32_t _paddingRight;
+  uint32_t _paddingBottom;
+
+  uint32_t _strideWidth;
+  uint32_t _strideHeight;
+
+  FuseCode _activation;
+};
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
+
+#endif // __INTERNAL_KERNELS_CPU_CPUCONVOLUTIONLAYER_H__
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.cc b/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.cc
new file mode 100644 (file)
index 0000000..75802c6
--- /dev/null
@@ -0,0 +1,62 @@
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx)
+{
+  if (dimensionIdx >= shape.dimensions.size())
+  {
+    // TODO: log the error
+    return 0;
+  }
+  return shape.dimensions[dimensionIdx];
+}
+
+void CalculateActivationRangeFloat(int32_t activation, float *activation_min, float *activation_max)
+{
+  if (activation == ANEURALNETWORKS_FUSED_RELU)
+  {
+    *activation_min = 0.f;
+    *activation_max = std::numeric_limits<float>::max();
+  }
+  else if (activation == ANEURALNETWORKS_FUSED_RELU6)
+  {
+    *activation_min = 0.f;
+    *activation_max = 6.f;
+  }
+  else if (activation == ANEURALNETWORKS_FUSED_RELU1)
+  {
+    *activation_min = -1.f;
+    *activation_max = 1.f;
+  }
+  else if (activation == ANEURALNETWORKS_FUSED_NONE)
+  {
+    *activation_min = std::numeric_limits<float>::lowest();
+    *activation_max = std::numeric_limits<float>::max();
+  }
+  else
+  {
+    std::cout << "Unsupported fused activation function." << std::endl;
+  }
+}
+
+Shape convertShape(const ::internal::tflite::operand::Shape &o)
+{
+  Shape shape;
+
+  shape.type = static_cast<OperandType>(o.type());
+  shape.dimensions = std::vector<uint32_t>(o.dims().begin(), o.dims().end());
+  shape.scale = o.scale();
+  //shape.offset = _offset;
+
+  return shape;
+}
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.h b/runtimes/new_runtime/src/internal/kernels/cpufallback/OperationUtils.h
new file mode 100644 (file)
index 0000000..5096063
--- /dev/null
@@ -0,0 +1,69 @@
+#ifndef __NNFW_SUPPORT_NNAPI_OPERATION_UTILS_H__
+#define __NNFW_SUPPORT_NNAPI_OPERATION_UTILS_H__
+
+#include <NeuralNetworks.h>
+
+#include <iostream>
+#include <limits>
+#include <vector>
+
+#include "internal/Model.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+enum class OperandType : int32_t {
+    FLOAT32 = 0,
+    INT32 = 1,
+    UINT32 = 2,
+    TENSOR_FLOAT32 = 3,
+    TENSOR_INT32 = 4,
+    TENSOR_QUANT8_ASYMM = 5,
+    OEM = 10000,
+    TENSOR_OEM_BYTE = 10001,
+};
+
+struct Shape {
+    OperandType type;
+    std::vector<uint32_t> dimensions;
+    float scale;
+    int32_t offset;
+};
+
+uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx);
+
+inline ::tflite::Dims<4> convertShapeToDims(const Shape& shape) {
+  //nnAssert(shape.dimensions.size() <= 4);
+  ::tflite::Dims<4> dims;
+  // The dimensions are reversed in Dims<4>.
+  for (int i = 0; i < 4; ++i) {
+    int src = static_cast<int>(shape.dimensions.size()) - i - 1;
+    if (src >= 0) {
+      dims.sizes[i] = static_cast<int>(getSizeOfDimension(shape, src));
+    } else {
+      dims.sizes[i] = 1;
+    }
+  }
+  dims.strides[0] = 1;
+  for (int i = 1; i<4; i++) {
+    dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1];
+  }
+  return dims;
+}
+
+void CalculateActivationRangeFloat(int32_t activation,
+                                   float* activation_min,
+                                   float* activation_max);
+
+Shape convertShape(const ::internal::tflite::operand::Shape &o);
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
+
+#endif // __NNFW_SUPPORT_NNAPI_OPERATION_UTILS_H__
diff --git a/runtimes/new_runtime/src/internal/nnapi/kernel/View.h b/runtimes/new_runtime/src/internal/nnapi/kernel/View.h
new file mode 100644 (file)
index 0000000..258f5ac
--- /dev/null
@@ -0,0 +1,73 @@
+#ifndef __INTERNAL_NNAPI_KERNEL_VIEW_H__
+#define __INTERNAL_NNAPI_KERNEL_VIEW_H__
+
+#include "util/kernel/Shape.h"
+#include "util/kernel/Reader.h"
+
+#include <arm_compute/core/ITensor.h>
+
+namespace internal
+{
+namespace nnapi
+{
+namespace kernel
+{
+
+template <typename T> class View final : public nnfw::util::kernel::Reader<float>
+{
+public:
+  View(::arm_compute::ITensor *tensor) : _tensor{tensor}
+  {
+    assert(tensor->info()->data_type() == ::arm_compute::DataType::F32);
+
+    _shape.N = tensor->info()->dimension(3);
+    _shape.C = tensor->info()->dimension(2);
+    _shape.H = tensor->info()->dimension(1);
+    _shape.W = tensor->info()->dimension(0);
+  }
+
+public:
+  const nnfw::util::kernel::Shape &shape(void) const { return _shape; }
+
+public:
+  float at(uint32_t nth, uint32_t row, uint32_t col, uint32_t ch) const override
+  {
+    // NNAPI uses NHWC ordering
+    uint32_t index = 0;
+
+    index += nth * _shape.H * _shape.W * _shape.C;
+    index += row * _shape.W * _shape.C;
+    index += col * _shape.C;
+    index += ch;
+
+    float *ptr = reinterpret_cast<float *>(_tensor->buffer());
+
+    return ptr[index];
+  }
+
+  float &at(uint32_t nth, uint32_t row, uint32_t col, uint32_t ch)
+  {
+    // NNAPI uses NHWC ordering
+    uint32_t index = 0;
+
+    index += nth * _shape.H * _shape.W * _shape.C;
+    index += row * _shape.W * _shape.C;
+    index += col * _shape.C;
+    index += ch;
+
+    float *ptr = reinterpret_cast<float *>(_tensor->buffer());
+
+    return ptr[index];
+  }
+
+private:
+  nnfw::util::kernel::Shape _shape;
+  ::arm_compute::ITensor *_tensor;
+
+};
+
+} // namespace kernel
+} // namespace nnapi
+} // namespace internal
+
+#endif // __INTERNAL_NNAPI_KERNEL_VIEW_H__
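
The at() accessors above flatten (nth, row, col, ch) coordinates with the NHWC formula
index = nth*H*W*C + row*W*C + col*C + ch, i.e. ((nth*H + row)*W + col)*C + ch. The
standalone check below (arbitrary example dimensions, nothing taken from the runtime)
verifies that iterating in N, H, W, C order with that formula visits consecutive offsets.

// Standalone check of the NHWC flat-index formula used by nnapi::kernel::View::at().
#include <cassert>
#include <cstdint>

int main()
{
  const uint32_t N = 2, H = 3, W = 4, C = 5;   // arbitrary example dimensions
  uint32_t flat = 0;
  for (uint32_t nth = 0; nth < N; ++nth)
    for (uint32_t row = 0; row < H; ++row)
      for (uint32_t col = 0; col < W; ++col)
        for (uint32_t ch = 0; ch < C; ++ch)
        {
          const uint32_t index = nth * H * W * C + row * W * C + col * C + ch;
          assert(index == flat++);             // NHWC iteration order is contiguous
        }
  return 0;
}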
diff --git a/runtimes/new_runtime/src/model.cc b/runtimes/new_runtime/src/model.cc
index ada6cee..87e79ca 100644 (file)
@@ -32,6 +32,8 @@ int ANeuralNetworksModel_addOperand(ANeuralNetworksModel *model,
     shape.dim(axis) = type->dimensions[axis];
   }
 
+  shape.set(type->type, type->scale);
+
   model->deref().operands().append(shape);
 
   // NOTE We do NOT allocate CLTensor here as we do not how to interpret this one.