Only the Conv2D operation is supported so far.
Implement #1607 in the current runtime structure.
Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
file(GLOB_RECURSE SOURCES "src/*.cc")
add_library(${LIB_NEW_RUNTIME} SHARED ${SOURCES})
-target_include_directories(${LIB_NEW_RUNTIME} PUBLIC ${NNFW_INCLUDE_DIR}
- ${CMAKE_SOURCE_DIR}/externals/acl
- ${CMAKE_SOURCE_DIR}/externals/acl/include)
-target_include_directories(${LIB_NEW_RUNTIME} PUBLIC src)
+target_include_directories(${LIB_NEW_RUNTIME} PUBLIC ${NNFW_INCLUDE_DIR})
+target_include_directories(${LIB_NEW_RUNTIME} PUBLIC src
+ ${CMAKE_SOURCE_DIR}/externals/tensorflow)
target_link_libraries(${LIB_NEW_RUNTIME} arm_compute)
+target_link_libraries(${LIB_NEW_RUNTIME} tensorflow-lite)
target_link_libraries(${LIB_NEW_RUNTIME} nnfw_util)
+target_link_libraries(${LIB_NEW_RUNTIME} nnfw_support_nnapi)
set_target_properties(${LIB_NEW_RUNTIME} PROPERTIES OUTPUT_NAME neuralnetworks)
install(TARGETS ${LIB_NEW_RUNTIME} DESTINATION lib/new_runtime)
BackendResolver(::internal::BackendManager &backend_manager)
{
auto acl_gen = backend_manager.get("arm_compute");
+ auto cpu_gen = backend_manager.get("cpu");
// TODO Set generator map according to environment variable
- _gen_map[typeid(::internal::tflite::op::Conv2D::implicit::Node)] = acl_gen;
+ _gen_map[typeid(::internal::tflite::op::Conv2D::implicit::Node)] = cpu_gen;
_gen_map[typeid(::internal::tflite::op::MaxPool2D::implicit::Node)] = acl_gen;
_gen_map[typeid(::internal::tflite::op::AvgPool2D::implicit::Node)] = acl_gen;
_gen_map[typeid(::internal::tflite::op::Concat::Node)] = acl_gen;
#include "internal/nnapi/feature/Reader.h"
#include "internal/nnapi/feature/View.h"
+#include "internal/cpu.h"
+#include <arm_compute/runtime/CL/CLTensor.h>
+
#include "internal/arm_compute/feature/View.h"
#include "util/feature/IndexIterator.h"
public:
void push(::arm_compute::ITensor &tensor) const override
{
- const ::internal::nnapi::feature::Reader<float> from{_shape, _base, _size};
- ::internal::arm_compute::feature::View<float> into{&tensor};
+    // TODO: This is just a workaround; it needs to be refactored.
+ if (typeid(tensor) == typeid(::internal::cpu::Tensor))
+ {
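+      // cpu::Tensor keeps its data in a host-visible buffer, so copy straight into it
+      // through an NNAPI feature view.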
+ const ::internal::nnapi::feature::Reader<float> from{_shape, _base, _size};
+ ::internal::nnapi::feature::View<float> into{_shape, tensor.buffer(), _size};
+
+ ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(ch, row, col);
+ into.at(ch, row, col) = value;
+ };
+ }
+ else if (typeid(tensor) == typeid(::arm_compute::CLTensor))
+ {
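+      // For CLTensor the elements are written through the arm_compute feature view
+      // instead of the raw buffer.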
+ const ::internal::nnapi::feature::Reader<float> from{_shape, _base, _size};
+ ::internal::arm_compute::feature::View<float> into{&tensor};
- ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(ch, row, col);
- into.at(ch, row, col) = value;
- };
+ ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(ch, row, col);
+ into.at(ch, row, col) = value;
+ };
+ }
}
private:
public:
void pull(::arm_compute::ITensor &tensor) const override
{
- const ::internal::arm_compute::feature::View<float> from{&tensor};
- ::internal::nnapi::feature::View<float> into{_shape, _base, _size};
+    // TODO: This is just a workaround; it needs to be refactored.
+ if (typeid(tensor) == typeid(::internal::cpu::Tensor))
+ {
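+      // Read directly from the cpu::Tensor's host buffer via an NNAPI feature reader.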
+ const ::internal::nnapi::feature::Reader<float> from{_shape, tensor.buffer(), _size};
+ ::internal::nnapi::feature::View<float> into{_shape, _base, _size};
- ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
- const auto value = from.at(ch, row, col);
- into.at(ch, row, col) = value;
- };
+ ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(ch, row, col);
+ into.at(ch, row, col) = value;
+ };
+ }
+ else if (typeid(tensor) == typeid(::arm_compute::CLTensor))
+ {
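+      // Read from the CLTensor through the arm_compute feature view before writing
+      // into the output buffer.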
+ const ::internal::arm_compute::feature::View<float> from{&tensor};
+ ::internal::nnapi::feature::View<float> into{_shape, _base, _size};
+
+ ::nnfw::util::feature::iterate(_shape) << [&](uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(ch, row, col);
+ into.at(ch, row, col) = value;
+ };
+ }
}
private:
#include "internal/arm_compute/TensorBuilder.h"
#include "internal/arm_compute/InitializerGenerator.h"
#include "internal/arm_compute/StageGenerator.h"
+#include "internal/cpu/TensorBuilder.h"
+#include "internal/cpu/InitializerGenerator.h"
+#include "internal/cpu/StageGenerator.h"
namespace internal
{
_gen_map["arm_compute"] = {acl_initializer_gen, acl_stage_gen};
}
- // TODO Add CPU backend
+ // Add CPU backend
+ {
+ auto cpu_tensor_builder = std::make_shared<::internal::cpu::TensorBuilder>(_plan);
+ auto cpu_initializer_gen = std::make_shared<::internal::cpu::InitializerGenerator>(operands);
+ auto cpu_stage_gen = std::make_shared<::internal::cpu::StageGenerator>(operands, cpu_tensor_builder);
+
+ _gen_map["cpu"] = {cpu_initializer_gen, cpu_stage_gen};
+ }
}
Backend BackendManager::get(const std::string &key)
public:
int32_t dim(uint32_t n) const { return _dims.at(n); }
int32_t &dim(uint32_t n) { return _dims.at(n); }
+ const std::vector<int32_t> &dims() const { return _dims; }
+ int32_t type() const { return _type; }
+ float scale() const { return _scale; }
+ void set(int32_t type, float scale)
+ {
+ _type = type;
+ _scale = scale;
+ }
public:
int32_t asVector(void) const;
private:
std::vector<int32_t> _dims;
+  int32_t _type = 0;
+  float _scale = 0.0f;
};
} // namespace operand
--- /dev/null
+#include "internal/cpu.h"
+
+namespace internal
+{
+namespace cpu
+{
+namespace operand
+{
+
+void Object::access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const
+{
+ fn(*_tensor);
+}
+
+} // namespace operand
+} // namespace cpu
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_CPU_H__
+#define __INTERNAL_CPU_H__
+
+#include <arm_compute/core/ITensor.h>
+#include <arm_compute/core/TensorInfo.h>
+#include <arm_compute/core/CL/OpenCL.h>
+
+#include <unistd.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <iostream>
+
+#include "internal/IObject.h"
+
+namespace internal
+{
+namespace cpu
+{
+namespace operand
+{
+
+class Object : public ::internal::IObject
+{
+public:
+ Object() = default;
+
+public:
+ Object(const std::shared_ptr<::arm_compute::ITensor> &tensor) : _tensor{tensor}
+ {
+ // DO NOTHING
+ }
+
+public:
+ ::arm_compute::ITensor *ptr(void) const override { return _tensor.get(); }
+
+private:
+ std::shared_ptr<::arm_compute::ITensor> _tensor;
+
+public:
+ void access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const override;
+};
+
+} // namespace operand
+} // namespace cpu
+} // namespace internal
+
+
+namespace internal
+{
+namespace cpu
+{
+
+#ifndef PATH_MAX
+#define PATH_MAX 256
+#endif
+
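+// Create an unlinked temporary file of the requested size and return its descriptor;
+// callers mmap it to obtain scratch memory for a Tensor.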
+static int shmem_num = 0;
+static int shmem_create_region(size_t size)
+{
+ char temp[PATH_MAX];
+ snprintf(temp, sizeof(temp), "/tmp/nn-shmem-%d-%d-XXXXXXXXX", getpid(), shmem_num++);
+ int fd = mkstemp(temp);
+ if (fd == -1) return -1;
+
+ unlink(temp);
+
+ if (TEMP_FAILURE_RETRY(ftruncate(fd, size)) == -1) {
+ close(fd);
+ return -1;
+ }
+
+ return fd;
+}
+
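+// Host-memory tensor implementing arm_compute::ITensor, so the CPU backend can be
+// plugged into the same IObject/ITensorBuilder machinery as the ACL backend.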
+class Tensor : public ::arm_compute::ITensor
+{
+public:
+ Tensor() = default;
+
+ Tensor(::arm_compute::TensorInfo info) : _info(info)
+ {
+ uint32_t size = _info.total_size();
+ // TODO Do not use shared memory
+    int fd = shmem_create_region(size);
+    _buffer = reinterpret_cast<uint8_t *>(
+        mmap(nullptr, size, PROT_WRITE | PROT_READ, MAP_PRIVATE, fd, 0));
+    // The mapping keeps the region alive, so the descriptor is no longer needed.
+    close(fd);
+  }
+
+ Tensor(uint8_t *buffer) : _buffer(buffer)
+ {
+ // DO NOTHING
+ }
+
+public:
+
+ void setBuffer(uint8_t *buffer)
+ {
+ _buffer = buffer;
+ }
+
+public:
+ ::arm_compute::TensorInfo *info() const override
+ {
+ return const_cast<::arm_compute::TensorInfo*>(&_info);
+ }
+
+ ::arm_compute::TensorInfo *info() override
+ {
+ return &_info;
+ }
+
+ uint8_t *buffer() const override
+ {
+ return _buffer;
+ }
+
+private:
+ ::arm_compute::TensorInfo _info;
+ uint8_t *_buffer = nullptr;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_H__
--- /dev/null
+#include "internal/cpu/InitializerGenerator.h"
+
+#include "internal/nnapi/kernel/Reader.h"
+#include "internal/nnapi/kernel/View.h"
+#include "util/kernel/IndexIterator.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+InitializerGenerator::InitializerGenerator(const ::internal::tflite::operand::Set &ctx) : _ctx(ctx)
+{
+ // DO NOTHING
+}
+
+Initializer InitializerGenerator::generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+ const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
+
+ const auto ker_shape = _ctx.at(ker_index).shape().asKernel();
+ auto ker_base = _ctx.at(ker_index).data().base();
+ auto ker_size = _ctx.at(ker_index).data().size();
+
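+  // Return an initializer that copies the Conv2D kernel weights from the model's
+  // constant data into the backend tensor, element by element.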
+ return [ker_shape, ker_base, ker_size](::arm_compute::ITensor &tensor) {
+ const ::internal::nnapi::kernel::Reader<float> from{ker_shape, ker_base, ker_size};
+ ::internal::nnapi::kernel::View<float> into{&tensor};
+
+ ::nnfw::util::kernel::iterate(ker_shape)
+ << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
+ const auto value = from.at(nth, ch, row, col);
+ into.at(nth, row, col, ch) = value;
+ };
+ };
+}
+
+Initializer InitializerGenerator::generateWeight(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+Initializer InitializerGenerator::generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+ // TODO Refactor so we can reuse the common code
+
+ const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+ auto bias_base = _ctx.at(bias_index).data().base();
+ const auto bias_size = _ctx.at(bias_index).shape().asVector();
+
+ return [bias_base, bias_size](::arm_compute::ITensor &tensor) {
+ for (uint32_t n = 0; n < bias_size; ++n)
+ {
+ const ::arm_compute::Coordinates coordinate{n};
+
+ float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+ const float *from = reinterpret_cast<const float *>(bias_base) + n;
+ const auto value = *from;
+
+ *into = value;
+ }
+ };
+}
+
+Initializer InitializerGenerator::generateBias(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+} // namespace cpu
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_CPU_INITIALIZER_GENERATOR_H__
+#define __INTERNAL_CPU_INITIALIZER_GENERATOR_H__
+
+#include "internal/IInitializerGenerator.h"
+
+#include "internal/Model.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+class InitializerGenerator : public ::internal::IInitializerGenerator
+{
+public:
+ InitializerGenerator(const ::internal::tflite::operand::Set &ctx);
+
+ Initializer generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+ Initializer generateWeight(const ::internal::tflite::op::FullyConnected::Node &node) override;
+
+ Initializer generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+ Initializer generateBias(const ::internal::tflite::op::FullyConnected::Node &node) override;
+
+private:
+ const ::internal::tflite::operand::Set &_ctx;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_INITIALIZER_GENERATOR_H__
--- /dev/null
+//#include "internal/cpu/MemoryAllocator.h"
+
--- /dev/null
+#ifndef __INTERNAL_CPU_MEMORY_ALLOCATOR_H__
+#define __INTERNAL_CPU_MEMORY_ALLOCATOR_H__
+
+#include "arm_compute/runtime/ITensorAllocator.h"
+#include "arm_compute/runtime/Memory.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+namespace arm_compute
+{
+class Coordinates;
+class TensorInfo;
+class Tensor;
+class MemoryGroup;
+} // namespace arm_compute
+
+// NOTE This header adapts arm_compute's TensorAllocator declaration. Pull in the
+// arm_compute namespace so the unqualified types below (ITensorAllocator, Memory,
+// MemoryGroup, ...) resolve.
+using namespace ::arm_compute;
+
+/** Basic implementation of a CPU memory tensor allocator. */
+class TensorAllocator : public ITensorAllocator
+{
+public:
+ /** Default constructor. */
+ TensorAllocator(Tensor *owner = nullptr);
+ /** Default destructor */
+ ~TensorAllocator();
+
+ /** Make ITensorAllocator's init methods available */
+ using ITensorAllocator::init;
+
+ /** Shares the same backing memory with another tensor allocator, while the tensor info might be different.
+ * In other words this can be used to create a sub-tensor from another tensor while sharing the same memory.
+ *
+ * @note TensorAllocator have to be of the same specialized type.
+ *
+ * @param[in] allocator The allocator that owns the backing memory to be shared. Ownership becomes shared afterwards.
+ * @param[in] coords The starting coordinates of the new tensor inside the parent tensor.
+ * @param[in] sub_info The new tensor information (e.g. shape etc)
+ */
+ void init(const TensorAllocator &allocator, const Coordinates &coords, TensorInfo sub_info);
+
+ /** Returns the pointer to the allocated data. */
+ uint8_t *data() const;
+
+ /** Allocate size specified by TensorInfo of CPU memory.
+ *
+ * @note The tensor must not already be allocated when calling this function.
+ *
+ */
+ void allocate() override;
+
+ /** Free allocated CPU memory.
+ *
+ * @note The tensor must have been allocated when calling this function.
+ *
+ */
+ void free() override;
+ /** Import an existing memory as a tensor's backing memory
+ *
+ * @warning If the tensor is flagged to be managed by a memory manager,
+ * this call will lead to an error.
+ * @warning Ownership of memory depends on the way the @ref Memory object was constructed
+ * @note Calling free on a tensor with imported memory will just clear
+ * the internal pointer value.
+ *
+ * @param[in] memory Memory to import
+ *
+ * @return error status
+ */
+ arm_compute::Status import_memory(Memory memory);
+ /** Associates the tensor with a memory group
+ *
+ * @param[in] associated_memory_group Memory group to associate the tensor with
+ */
+ void set_associated_memory_group(MemoryGroup *associated_memory_group);
+
+protected:
+ /** No-op for CPU memory
+ *
+ * @return A pointer to the beginning of the tensor's allocation.
+ */
+ uint8_t *lock() override;
+
+ /** No-op for CPU memory. */
+ void unlock() override;
+
+private:
+ MemoryGroup *_associated_memory_group; /**< Registered memory manager */
+ Memory _memory; /**< CPU memory */
+ Tensor *_owner; /**< Owner of the allocator */
+};
+
+namespace internal
+{
+namespace cpu
+{
+
+// TODO Implement a CPU memory allocator.
+class MemoryAllocator
+{
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_MEMORY_ALLOCATOR_H__
--- /dev/null
+#include "internal/cpu/StageGenerator.h"
+
+#include <stdexcept>
+
+#include "internal/Padding.h"
+#include "internal/kernels/cpufallback/CPUConvolutionLayer.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+StageGenerator::StageGenerator(const ::internal::tflite::operand::Set &operand_ctx,
+ const std::shared_ptr<::internal::cpu::TensorBuilder> &tensor_builder)
+ : _ctx(operand_ctx), _tensor_builder(tensor_builder)
+{
+ // DO NOTHING
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+ const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+ const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+ const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
+ const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+ const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
+ const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
+
+ const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
+ const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
+
+ const PaddingCode padding_type =
+ static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
+
+ assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
+ (ANEURALNETWORKS_PADDING_VALID == padding_type));
+
+ Stride stride;
+
+ stride.vertical = _ctx.at(vstride_index).asScalar<int32_t>();
+ stride.horizontal = _ctx.at(hstride_index).asScalar<int32_t>();
+
+ // Construct operation parameters
+ struct Param
+ {
+ int ofm_index;
+ int ifm_index;
+ int ker_index;
+ int bias_index;
+
+ ::internal::tflite::operand::Shape ofm_shape{1};
+ ::internal::tflite::operand::Shape ifm_shape{1};
+ ::internal::tflite::operand::Shape ker_shape{1};
+ ::internal::tflite::operand::Shape bias_shape{1};
+
+ Padding padding;
+ Stride stride;
+
+ FuseCode activation;
+ };
+
+ Param param;
+
+ param.ofm_index = ofm_index.asInt();
+ param.ifm_index = ifm_index.asInt();
+ param.ker_index = ker_index.asInt();
+ param.bias_index = bias_index.asInt();
+
+ param.ofm_shape = _ctx.at(ofm_index).shape();
+ param.ifm_shape = _ctx.at(ifm_index).shape();
+ param.ker_shape = _ctx.at(ker_index).shape();
+ param.bias_shape = _ctx.at(bias_index).shape();
+
+ param.stride = stride;
+ param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
+ ? same_padding(param.ifm_shape.asFeature(), param.ofm_shape.asFeature(),
+ stride, param.ker_shape.asKernel().W,
+ param.ker_shape.asKernel().H)
+ : valid_padding();
+
+ param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
+
+ auto tensors = _tensor_builder;
+
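+  // The returned stage captures the parameters and the tensor builder by value so it
+  // can be executed later, when the plan runs.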
+ return [tensors, param](IExecutionBuilder &builder) {
+ auto ofm_alloc = tensors->at(::internal::tflite::operand::Index{param.ofm_index});
+ auto ifm_alloc = tensors->at(::internal::tflite::operand::Index{param.ifm_index});
+ auto ker_alloc = tensors->at(::internal::tflite::operand::Index{param.ker_index});
+ auto bias_alloc = tensors->at(::internal::tflite::operand::Index{param.bias_index});
+
+ std::unique_ptr<::internal::kernels::cpu::CPUConvolutionLayer> fn{
+ new ::internal::kernels::cpu::CPUConvolutionLayer};
+
+ fn->configure(ifm_alloc->buffer(), param.ifm_shape, ker_alloc->buffer(), param.ker_shape, bias_alloc->buffer(),
+ param.bias_shape, param.padding.left, param.padding.right, param.padding.top, param.padding.bottom,
+ param.stride.horizontal, param.stride.vertical, param.activation, ofm_alloc->buffer(),
+ param.ofm_shape);
+
+ builder.append(std::move(fn));
+ };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Concat::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Reshape::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Softmax::Node &node)
+{
+ throw std::runtime_error("NYI");
+}
+
+} // namespace cpu
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_CPU_STAGE_GENERATOR_H__
+#define __INTERNAL_CPU_STAGE_GENERATOR_H__
+
+#include "internal/IStageGenerator.h"
+
+#include "internal/Model.h"
+#include "internal/cpu.h"
+#include "internal/cpu/TensorBuilder.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+class StageGenerator : public ::internal::IStageGenerator
+{
+public:
+ StageGenerator(const ::internal::tflite::operand::Set &ctx,
+ const std::shared_ptr<::internal::cpu::TensorBuilder> &tensor_builder);
+
+ virtual std::shared_ptr<ITensorBuilder> tensor_builder() override { return _tensor_builder; }
+
+ virtual Stage generate(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+ virtual Stage generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) override;
+ virtual Stage generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) override;
+ virtual Stage generate(const ::internal::tflite::op::Concat::Node &node) override;
+ virtual Stage generate(const ::internal::tflite::op::FullyConnected::Node &node) override;
+ virtual Stage generate(const ::internal::tflite::op::Reshape::Node &node) override;
+ virtual Stage generate(const ::internal::tflite::op::Softmax::Node &node) override;
+
+private:
+ const ::internal::tflite::operand::Set &_ctx;
+ std::shared_ptr<::internal::cpu::TensorBuilder> _tensor_builder;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_STAGE_GENERATOR_H__
--- /dev/null
+#include "internal/cpu/TensorBuilder.h"
+
+#include <cassert>
+
+#include "internal/arm_compute.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+TensorBuilder::TensorBuilder(::internal::arm_compute::Plan &plan) : _plan(plan)
+{
+ // DO NOTHING
+}
+
+void TensorBuilder::mark(const ::internal::tflite::operand::Index &ind)
+{
+ assert(_tensors.size() == 0);
+
+ _inds.insert(ind.asInt());
+}
+
+void TensorBuilder::prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx)
+{
+ assert(_tensors.size() == 0);
+
+ for (auto ind_int : _inds)
+ {
+ ::internal::tflite::operand::Index ind{ind_int};
+ auto tensor = std::make_shared<::internal::cpu::Tensor>(tensor_info_ctx.at(ind.asInt()));
+ // TODO Fix allocation here. When Tensor object is created the memory for tensor is also
+ // allocated, and this must be fixed.
+ _plan.operands().set(ind, std::make_shared<::internal::cpu::operand::Object>(tensor));
+ _tensors[ind.asInt()] = tensor;
+ }
+}
+
+void TensorBuilder::allocate(void)
+{
+ assert(_inds.size() == _tensors.size());
+
+ // NOTE For now nothing to do. Allocation is done in prepare stage, which is wrong
+ // See also: comment in `prepare()`
+}
+
+std::shared_ptr<::internal::cpu::Tensor> TensorBuilder::at(const ::internal::tflite::operand::Index &ind)
+{
+ return _tensors.at(ind.asInt());
+}
+
+} // namespace cpu
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_CPU_TENSOR_BUILDER_H__
+#define __INTERNAL_CPU_TENSOR_BUILDER_H__
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "internal/ITensorBuilder.h"
+#include "internal/cpu.h"
+#include "internal/arm_compute.h"
+
+namespace internal
+{
+namespace cpu
+{
+
+class Plan;
+
+class TensorBuilder : public ::internal::ITensorBuilder
+{
+public:
+ TensorBuilder(::internal::arm_compute::Plan &plan);
+
+  virtual void mark(const ::internal::tflite::operand::Index &ind) override;
+ virtual void prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx) override;
+ virtual void allocate(void) override;
+
+ std::shared_ptr<::internal::cpu::Tensor> at(const ::internal::tflite::operand::Index &ind);
+
+private:
+ ::internal::arm_compute::Plan &_plan;
+ std::unordered_set<int> _inds;
+ std::unordered_map<int, std::shared_ptr<::internal::cpu::Tensor>> _tensors;
+};
+
+} // namespace cpu
+} // namespace internal
+
+#endif // __INTERNAL_CPU_TENSOR_BUILDER_H__
--- /dev/null
+#include "CPUConvolutionLayer.h"
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+// If possible we will use this static buffer for the tensor.
+static constexpr int kStaticBufferSize = 1605632;
+static char static_scratch_buffer[kStaticBufferSize];
+
+#define ANDROID_NN_CONV_PARAMETERS(Type) \
+ uint32_t height = getSizeOfDimension(_inputShape, 1); \
+ uint32_t width = getSizeOfDimension(_inputShape, 2); \
+ uint32_t kernelHeight = getSizeOfDimension(_kernelShape, 1); \
+ uint32_t kernelWidth = getSizeOfDimension(_kernelShape, 2); \
+ uint32_t outHeight = getSizeOfDimension(_outputShape, 1); \
+ uint32_t outWidth = getSizeOfDimension(_outputShape, 2); \
+ uint32_t inDepth = getSizeOfDimension(_inputShape, 3); \
+ \
+ uint32_t paddingHeight = (uint32_t)_paddingTop; \
+ uint32_t paddingWidth = (uint32_t)_paddingLeft; \
+ \
+ ::tflite::Dims<4> im2colDim; \
+ im2colDim.sizes[3] = (int)getSizeOfDimension(_outputShape, 0); \
+ im2colDim.sizes[2] = (int)getSizeOfDimension(_outputShape, 1); \
+ im2colDim.sizes[1] = (int)getSizeOfDimension(_outputShape, 2); \
+ im2colDim.sizes[0] = (int)inDepth * kernelHeight * kernelWidth; \
+ \
+ im2colDim.strides[0] = 1; \
+ for (int i=1; i<4; i++) { \
+ im2colDim.strides[i] = im2colDim.strides[i-1] * im2colDim.sizes[i-1]; \
+ } \
+ Type* im2colData = nullptr; \
+ uint64_t im2colByteSize = sizeof(Type); \
+ std::unique_ptr<Type[]> im2colGuard; \
+ for (int i=0; i<4; i++) { \
+ im2colByteSize *= im2colDim.sizes[i]; \
+ } \
+ /* http://b/77982879, tflite::optimized_ops::Conv uses int for offsets */ \
+ if (im2colByteSize >= 0x7fffffff) { \
+ std::cout << "Conv size is too large, not enough memory" << std::endl; \
+ return false; \
+ } \
+ if (im2colByteSize <= kStaticBufferSize) { \
+ im2colData = reinterpret_cast<Type *>(static_scratch_buffer); \
+ } else { \
+ im2colData = new (std::nothrow) Type[im2colByteSize / sizeof(Type)]; \
+ if (im2colData == nullptr) { \
+ std::cout << "Conv size is too large, not enough memory" << std::endl; \
+ return false; \
+ } \
+ im2colGuard.reset(im2colData); \
+ }
+
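+// Float32 convolution delegated to TensorFlow Lite's optimized_ops::Conv, using the
+// im2col scratch buffer prepared by ANDROID_NN_CONV_PARAMETERS.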
+bool CPUConvolutionLayer::convFloat32()
+{
+ ANDROID_NN_CONV_PARAMETERS(float)
+
+ float output_activation_min, output_activation_max;
+ CalculateActivationRangeFloat(_activation, &output_activation_min,
+ &output_activation_max);
+ int32_t dilationWidthFactor = 1, dilationHeightFactor = 1;
+ ::tflite::optimized_ops::Conv(
+ reinterpret_cast<const float *>(_inputData), convertShapeToDims(_inputShape),
+ reinterpret_cast<const float *>(_kernelData), convertShapeToDims(_kernelShape),
+ reinterpret_cast<const float *>(_biasData), convertShapeToDims(_biasShape),
+ _strideWidth, _strideHeight,
+ dilationWidthFactor, dilationHeightFactor,
+ paddingWidth, paddingHeight,
+ output_activation_min, output_activation_max,
+ reinterpret_cast<float *>(_outputData), convertShapeToDims(_outputShape),
+ im2colData, im2colDim);
+ return true;
+}
+
+bool CPUConvolutionLayer::convQuant8()
+{
+/*
+ ANDROID_NN_CONV_PARAMETERS(uint8_t)
+
+ int32_t inputOffset = -inputShape.offset;
+ int32_t filterOffset = -filterShape.offset;
+ int32_t outputOffset = outputShape.offset;
+
+ float real_multiplier = 0.0;
+ int32_t output_multiplier = 0;
+ int32_t output_shift = 0;
+ int32_t output_activation_min = 0;
+ int32_t output_activation_max = 0;
+
+ if (!GetQuantizedConvolutionMultipler(inputShape, filterShape, biasShape, outputShape,
+ &real_multiplier) ||
+ !QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift))
+ {
+ // Following code inserted to resolve Coverity (118950 Resource leak)
+ if (im2colByteSize > kStaticBufferSize)
+ {
+ delete[] im2colData;
+ }
+ return false;
+ }
+ CalculateActivationRangeUint8(activation, outputShape, &output_activation_min,
+ &output_activation_max);
+
+ static gemmlowp::GemmContext gemm_context;
+  // Allow gemmlowp to automatically decide how many threads to use.
+ gemm_context.set_max_num_threads(0);
+
+#define ANDROID_NN_CONV(activation) \
+ optimized_ops::Conv<FusedActivationFunctionType::activation>( \
+ inputData, convertShapeToDims(inputShape), inputOffset, filterData, \
+ convertShapeToDims(filterShape), filterOffset, biasData, convertShapeToDims(biasShape), \
+ stride_width, stride_height, paddingWidth, paddingHeight, outputOffset, output_multiplier, \
+ output_shift, output_activation_min, output_activation_max, outputData, \
+ convertShapeToDims(outputShape), im2colData, im2colDim, &gemm_context)
+
+ ANDROID_NN_MACRO_DISPATCH_WITH_DELETE(ANDROID_NN_CONV)
+#undef ANDROID_NN_CONV
+
+ if (im2colByteSize > kStaticBufferSize)
+ {
+ delete[] im2colData;
+ }*/
+ return true;
+}
+
+void CPUConvolutionLayer::configure(uint8_t *inputData, const internal::tflite::operand::Shape inputShape, uint8_t *kernelData,
+ const internal::tflite::operand::Shape kernelShape, uint8_t *biasData, const internal::tflite::operand::Shape biasShape,
+ const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, const uint32_t paddingBottom,
+ const uint32_t strideWidth, const uint32_t strideHeight, const FuseCode activation, uint8_t *outputData, const internal::tflite::operand::Shape outputShape)
+{
+ _inputData = inputData;
+ _inputShape = convertShape(inputShape);
+ _kernelData = kernelData;
+ _kernelShape = convertShape(kernelShape);
+ _biasData = biasData;
+ _biasShape = convertShape(biasShape);
+ _paddingLeft = paddingLeft;
+ _paddingRight = paddingRight;
+ _paddingTop = paddingTop;
+ _paddingBottom = paddingBottom;
+ _strideWidth = strideWidth;
+ _strideHeight = strideHeight;
+ _activation = activation;
+ _outputData = outputData;
+ _outputShape = convertShape(outputShape);
+}
+
+
+void CPUConvolutionLayer::run()
+{
+ convFloat32();
+ /*
+ if (input.type == OperandType::TENSOR_FLOAT32)
+ {
+ }
+ else if (input.type == OperandType::TENSOR_QUANT8_ASYMM)
+ {
+ }
+ */
+}
+
+#undef ANDROID_NN_CONV_PARAMETERS
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_KERNELS_CPU_CPUCONVOLUTIONLAYER_H__
+#define __INTERNAL_KERNELS_CPU_CPUCONVOLUTIONLAYER_H__
+
+#include <NeuralNetworks.h>
+
+#include <arm_compute/runtime/IFunction.h>
+
+#include "internal/Model.h"
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+using namespace internal::kernels::cpu;
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+class CPUConvolutionLayer : public ::arm_compute::IFunction
+{
+public:
+ CPUConvolutionLayer()
+ {
+ }
+
+public:
+ bool convFloat32();
+
+ bool convQuant8();
+
+ void configure(uint8_t *inputData, const internal::tflite::operand::Shape inputShape, uint8_t *kernelData,
+ const internal::tflite::operand::Shape kernelShape, uint8_t *biasData, const internal::tflite::operand::Shape biasShape,
+ const uint32_t paddingLeft, const uint32_t paddingRight, const uint32_t paddingTop, const uint32_t paddingBottom,
+ const uint32_t strideW, const uint32_t strideH, const FuseCode activation, uint8_t *outputData, const internal::tflite::operand::Shape outputShape);
+
+ void run();
+
+private:
+ uint8_t *_inputData;
+ uint8_t *_kernelData;
+ uint8_t *_outputData;
+ uint8_t *_biasData;
+
+ Shape _inputShape;
+ Shape _kernelShape;
+ Shape _outputShape;
+ Shape _biasShape;
+
+ uint32_t _paddingLeft;
+ uint32_t _paddingTop;
+ uint32_t _paddingRight;
+ uint32_t _paddingBottom;
+
+ uint32_t _strideWidth;
+ uint32_t _strideHeight;
+
+ FuseCode _activation;
+};
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
+
+#endif // __INTERNAL_KERNELS_CPU_CPUCONVOLUTIONLAYER_H__
--- /dev/null
+#include "internal/kernels/cpufallback/OperationUtils.h"
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx)
+{
+ if (dimensionIdx >= shape.dimensions.size())
+ {
+ // TODO, log the error
+ return 0;
+ }
+ return shape.dimensions[dimensionIdx];
+}
+
+void CalculateActivationRangeFloat(int32_t activation, float *activation_min, float *activation_max)
+{
+ if (activation == ANEURALNETWORKS_FUSED_RELU)
+ {
+ *activation_min = 0.f;
+ *activation_max = std::numeric_limits<float>::max();
+ }
+ else if (activation == ANEURALNETWORKS_FUSED_RELU6)
+ {
+ *activation_min = 0.f;
+ *activation_max = 6.f;
+ }
+ else if (activation == ANEURALNETWORKS_FUSED_RELU1)
+ {
+ *activation_min = -1.f;
+ *activation_max = 1.f;
+ }
+ else if (activation == ANEURALNETWORKS_FUSED_NONE)
+ {
+ *activation_min = std::numeric_limits<float>::lowest();
+ *activation_max = std::numeric_limits<float>::max();
+ }
+ else
+ {
+ std::cout << "Unsupported fused activation function." << std::endl;
+ }
+}
+
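+// Translate the runtime operand shape into the Shape struct used by the CPU kernels.
+// The zero-point offset is not filled in yet (quantized tensors are not supported).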
+Shape convertShape(const ::internal::tflite::operand::Shape &o)
+{
+ Shape shape;
+
+ shape.type = static_cast<OperandType>(o.type());
+ shape.dimensions = std::vector<uint32_t>(o.dims().begin(), o.dims().end());
+ shape.scale = o.scale();
+  // TODO Set shape.offset once the operand shape carries a zero-point offset.
+
+ return shape;
+}
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_KERNELS_CPU_OPERATION_UTILS_H__
+#define __INTERNAL_KERNELS_CPU_OPERATION_UTILS_H__
+
+#include <NeuralNetworks.h>
+
+#include <iostream>
+#include <limits>
+#include <vector>
+
+#include "internal/Model.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace internal
+{
+namespace kernels
+{
+namespace cpu
+{
+
+enum class OperandType : int32_t {
+ FLOAT32 = 0,
+ INT32 = 1,
+ UINT32 = 2,
+ TENSOR_FLOAT32 = 3,
+ TENSOR_INT32 = 4,
+ TENSOR_QUANT8_ASYMM = 5,
+ OEM = 10000,
+ TENSOR_OEM_BYTE = 10001,
+};
+
+struct Shape {
+ OperandType type;
+ std::vector<uint32_t> dimensions;
+ float scale;
+ int32_t offset;
+};
+
+uint32_t getSizeOfDimension(const Shape &shape, uint32_t dimensionIdx);
+
+inline ::tflite::Dims<4> convertShapeToDims(const Shape& shape) {
+ //nnAssert(shape.dimensions.size() <= 4);
+ ::tflite::Dims<4> dims;
+ // The dimensions are reversed in Dims<4>.
+ for (int i = 0; i < 4; ++i) {
+ int src = static_cast<int>(shape.dimensions.size()) - i - 1;
+ if (src >= 0) {
+ dims.sizes[i] = static_cast<int>(getSizeOfDimension(shape, src));
+ } else {
+ dims.sizes[i] = 1;
+ }
+ }
+ dims.strides[0] = 1;
+ for (int i = 1; i<4; i++) {
+ dims.strides[i] = dims.strides[i-1] * dims.sizes[i-1];
+ }
+ return dims;
+}
+
+void CalculateActivationRangeFloat(int32_t activation,
+ float* activation_min,
+ float* activation_max);
+
+Shape convertShape(const ::internal::tflite::operand::Shape &o);
+
+} // namespace cpu
+} // namespace kernels
+} // namespace internal
+
+#endif // __INTERNAL_KERNELS_CPU_OPERATION_UTILS_H__
--- /dev/null
+#ifndef __INTERNAL_NNAPI_KERNEL_VIEW_H__
+#define __INTERNAL_NNAPI_KERNEL_VIEW_H__
+
+#include "util/kernel/Shape.h"
+#include "util/kernel/Reader.h"
+
+#include <arm_compute/core/ITensor.h>
+
+namespace internal
+{
+namespace nnapi
+{
+namespace kernel
+{
+
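+// Read/write view over a 4-D kernel tensor stored in the tensor's host buffer,
+// assuming NHWC element ordering as used by NNAPI.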
+template <typename T> class View final : public nnfw::util::kernel::Reader<float>
+{
+public:
+ View(::arm_compute::ITensor *tensor) : _tensor{tensor}
+ {
+ assert(tensor->info()->data_type() == ::arm_compute::DataType::F32);
+
+ _shape.N = tensor->info()->dimension(3);
+ _shape.C = tensor->info()->dimension(2);
+ _shape.H = tensor->info()->dimension(1);
+ _shape.W = tensor->info()->dimension(0);
+ }
+
+public:
+ const nnfw::util::kernel::Shape &shape(void) const { return _shape; }
+
+public:
+ float at(uint32_t nth, uint32_t row, uint32_t col, uint32_t ch) const override
+ {
+ // NNAPI uses NHWC ordering
+ uint32_t index = 0;
+
+ index += nth * _shape.H * _shape.W * _shape.C;
+ index += row * _shape.W * _shape.C;
+ index += col * _shape.C;
+ index += ch;
+
+ float *ptr = reinterpret_cast<float *>(_tensor->buffer());
+
+ return ptr[index];
+ }
+
+ float &at(uint32_t nth, uint32_t row, uint32_t col, uint32_t ch)
+ {
+ // NNAPI uses NHWC ordering
+ uint32_t index = 0;
+
+ index += nth * _shape.H * _shape.W * _shape.C;
+ index += row * _shape.W * _shape.C;
+ index += col * _shape.C;
+ index += ch;
+
+ float *ptr = reinterpret_cast<float *>(_tensor->buffer());
+
+ return ptr[index];
+ }
+
+private:
+ nnfw::util::kernel::Shape _shape;
+ ::arm_compute::ITensor *_tensor;
+
+};
+
+} // namespace kernel
+} // namespace nnapi
+} // namespace internal
+
+#endif // __INTERNAL_NNAPI_KERNEL_VIEW_H__
shape.dim(axis) = type->dimensions[axis];
}
+ shape.set(type->type, type->scale);
+
model->deref().operands().append(shape);
// NOTE We do NOT allocate CLTensor here as we do not how to interpret this one.