[new_runtime] Initial structure change (#1843)
author Hanjoung Lee / Motion Control Lab (SR) / Engineer / Samsung Electronics <hanjoung.lee@samsung.com>
Wed, 4 Jul 2018 05:31:01 +0000 (14:31 +0900)
committer Saehie Park / Motion Control Lab (SR) / Principal Engineer / Samsung Electronics <saehie.park@samsung.com>
Wed, 4 Jul 2018 05:31:01 +0000 (14:31 +0900)
Change runtime structure to support multiple backends

- Introduce `Backend`, which bundles an IInitializerGenerator and an
  IStageGenerator
- Introduce `BackendResolver` so we can choose which backend to use for
  each operation (see the sketch below)
- Introduce `IObject` so that tensors from different backends can be
  accessed through a common interface
- Move the arm_compute-specific implementation into the arm_compute
  directory
- Keep the Concat layer for later elimination, since it is hard to know
  up front whether the backends are heterogeneous
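
Illustrative sketch (not part of the patch): the snippet below condenses the
changes to compilation.cc to show how the new pieces are meant to fit
together for one operation; setup details, error handling and the full
visitor are omitted.

    // Backends are registered by name and then resolved per operation type.
    ::internal::BackendManager backend_manager{plan};
    BackendResolver backend_resolver{backend_manager}; // maps op typeid -> Backend

    // While visiting a node, the Planner asks the resolver for the generators
    // of the backend chosen for that operation ...
    auto init_gen = backend_resolver.getInitializerGenerator(typeid(node));
    auto stage_gen = backend_resolver.getStageGenerator(typeid(node));
    plan_builder.addInitializer(ker_index, init_gen->generateWeight(node));
    plan_builder.addStage(stage_gen->generate(node)); // backend-specific Stage

    // ... and finalization lets each backend's TensorBuilder prepare and
    // allocate the tensors it was marked responsible for.
    plan_builder.finalize(backend_resolver.getAllTensorBuilders());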

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
17 files changed:
runtimes/new_runtime/src/compilation.cc
runtimes/new_runtime/src/internal/BackendManager.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/BackendManager.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/IInitializerGenerator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/IObject.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/IStageGenerator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/ITensorBuilder.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/Padding.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/Padding.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/arm_compute.cc
runtimes/new_runtime/src/internal/arm_compute.h
runtimes/new_runtime/src/internal/arm_compute/InitializerGenerator.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/arm_compute/InitializerGenerator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/arm_compute/StageGenerator.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/arm_compute/StageGenerator.h [new file with mode: 0644]
runtimes/new_runtime/src/internal/arm_compute/TensorBuilder.cc [new file with mode: 0644]
runtimes/new_runtime/src/internal/arm_compute/TensorBuilder.h [new file with mode: 0644]

diff --git a/runtimes/new_runtime/src/compilation.cc b/runtimes/new_runtime/src/compilation.cc
index 53f87c4..9e13729 100644 (file)
 #include <NeuralNetworks.h>
 
+#include <typeindex>
+
 #include <arm_compute/core/CL/ICLTensor.h>
 
 #include <arm_compute/runtime/IFunction.h>
 #include <arm_compute/runtime/CL/CLScheduler.h>
-#include <arm_compute/runtime/CL/CLSubTensor.h>
-#include <arm_compute/runtime/CL/functions/CLPoolingLayer.h>
-#include <arm_compute/runtime/CL/functions/CLActivationLayer.h>
-#include <arm_compute/runtime/CL/functions/CLReshapeLayer.h>
-#include <arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h>
-#include <arm_compute/runtime/CL/functions/CLSoftmaxLayer.h>
 
 #include "internal/arm_compute/kernel/View.h"
+#include "internal/arm_compute/TensorBuilder.h"
 #include "internal/nnapi/kernel/Reader.h"
-
-#include "util/kernel/IndexIterator.h"
+#include "internal/Model.h"
+#include "internal/Padding.h"
+#include "internal/IInitializerGenerator.h"
+#include "internal/IStageGenerator.h"
 
 #include "compilation.h"
 #include "model.h"
 #include "logging.h"
 
-const char *to_string(const PaddingCode &code)
-{
-  assert((ANEURALNETWORKS_PADDING_SAME == code) || (ANEURALNETWORKS_PADDING_VALID == code));
-
-  switch (code)
-  {
-    case ANEURALNETWORKS_PADDING_SAME:
-      return "ANEURALNETWORKS_PADDING_SAME";
-    case ANEURALNETWORKS_PADDING_VALID:
-      return "ANEURALNETWORKS_PADDING_VALID";
-  }
-
-  return nullptr;
-}
-
-struct Padding
-{
-  uint32_t top;
-  uint32_t bottom;
-  uint32_t left;
-  uint32_t right;
-};
-
-struct Stride
-{
-  uint32_t vertical;
-  uint32_t horizontal;
-};
-
-Padding valid_padding(void)
-{
-  //
-  // ANEURALNETWORKS_PADDING_VALID
-  //
-  // VALID padding. No padding.
-  //
-  // When the input size is not evenly divisible by the filter size,
-  // the input at the end that could not fill the whole filter tile
-  // will simply be ignored.
-  //
-  Padding padding;
-
-  padding.top = 0;
-  padding.bottom = 0;
-  padding.left = 0;
-  padding.right = 0;
-
-  return padding;
-}
-
-Padding same_padding(const nnfw::util::feature::Shape &ifm_shape,
-                     const nnfw::util::feature::Shape &ofm_shape, const Stride &stride, uint32_t kw,
-                     uint32_t kh)
-{
-  Padding padding;
-
-  // ANEURALNETWORKS_PADDING_SAME (from NNAPI spec)
-  //
-  // SAME padding. Padding on both ends are the "same":
-  //
-  //   padding_to_beginning = total_padding / 2
-  //  padding_to_end = (total_padding + 1)/2.
-  //
-  const int32_t vertical_needed_input = (ofm_shape.H - 1) * stride.vertical + kh;
-  const int32_t vertical_total_padding = std::max(0, vertical_needed_input - ifm_shape.H);
-
-  const int32_t horizontal_needed_input = (ofm_shape.W - 1) * stride.horizontal + kw;
-  const int32_t horizontal_total_padding = std::max(0, horizontal_needed_input - ifm_shape.W);
-
-  padding.top = vertical_total_padding / 2;
-  padding.bottom = (vertical_total_padding + 1) / 2;
-  padding.left = horizontal_total_padding / 2;
-  padding.right = (horizontal_total_padding + 1) / 2;
-
-  return padding;
-}
-
-template <typename T> std::unique_ptr<T> make_layer(void) { return std::unique_ptr<T>{new T}; }
-
 ::arm_compute::TensorShape asTensorShape(int32_t h, int32_t w)
 {
   return ::arm_compute::TensorShape(w, h);
@@ -136,33 +56,7 @@ template <typename T> std::unique_ptr<T> make_layer(void) { return std::unique_ptr<T>{new T}; }
                                    ::arm_compute::DataType::F32);
 }
 
-::arm_compute::PadStrideInfo asPadStringInfo(const Padding &padding, const Stride &stride)
-{
-  return ::arm_compute::PadStrideInfo{stride.horizontal,
-                                      stride.vertical,
-                                      padding.left,
-                                      padding.right,
-                                      padding.top,
-                                      padding.bottom,
-                                      ::arm_compute::DimensionRoundingType::FLOOR};
-}
-
-struct IAllocationContext
-{
-  virtual ~IAllocationContext() = default;
-
-  virtual ::arm_compute::ICLTensor *at(const ::internal::tflite::operand::Index &ind) const = 0;
-};
-
-struct IExecutionBuilder
-{
-  virtual ~IExecutionBuilder() = default;
-
-  virtual void append(std::unique_ptr<::arm_compute::IFunction> &&f) = 0;
-};
-
-using Initializer = std::function<void(::arm_compute::ITensor &)>;
-using Stage = std::function<void(const IAllocationContext &, IExecutionBuilder &)>;
+using TensorSetter = std::function<void(int, const ::arm_compute::TensorInfo &)>;
 
 struct IPlanBuilder
 {
@@ -170,78 +64,75 @@ struct IPlanBuilder
 
   virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
                               const ::arm_compute::TensorInfo &info) = 0;
-  virtual void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
-                                    const ::internal::tflite::operand::Index &base,
-                                    const ::arm_compute::Coordinates &offset,
-                                    const ::arm_compute::TensorShape &shape) = 0;
   virtual void addInitializer(const ::internal::tflite::operand::Index &ind,
                               const Initializer &initializer) = 0;
   virtual void addStage(const Stage &) = 0;
 };
 
-//
-// ActivationBuilder
-//
-class ActivationBuilder
+#include "internal/BackendManager.h"
+
+class BackendResolver
 {
 public:
-  ActivationBuilder(IExecutionBuilder &builder) : _builder(builder)
+  BackendResolver(::internal::BackendManager &backend_manager)
   {
-    // DO NOTHING
+    auto acl_gen = backend_manager.get("arm_compute");
+
+    // TODO Set generator map according to environment variable
+    _gen_map[typeid(::internal::tflite::op::Conv2D::implicit::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::MaxPool2D::implicit::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::AvgPool2D::implicit::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::Concat::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::FullyConnected::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::Reshape::Node)] = acl_gen;
+    _gen_map[typeid(::internal::tflite::op::Softmax::Node)] = acl_gen;
   }
 
-private:
-  void appendReLU(::arm_compute::ICLTensor *tensor);
-
-public:
-  void append(FuseCode code, ::arm_compute::ICLTensor *tensor);
+  std::shared_ptr<::internal::IInitializerGenerator> getInitializerGenerator(const std::type_index &type);
+  std::shared_ptr<::internal::IStageGenerator> getStageGenerator(const std::type_index &type);
+  std::shared_ptr<::internal::ITensorBuilder> getTensorBuilder(const std::type_index &type);
+  std::set<std::shared_ptr<::internal::ITensorBuilder>> getAllTensorBuilders();
 
 private:
-  IExecutionBuilder &_builder;
+  std::unordered_map<std::type_index, ::internal::Backend> _gen_map;
 };
 
-void ActivationBuilder::appendReLU(::arm_compute::ICLTensor *ifm_alloc)
+std::shared_ptr<::internal::IInitializerGenerator> BackendResolver::getInitializerGenerator(const std::type_index &type)
 {
-  const ::arm_compute::ActivationLayerInfo act_info{
-      ::arm_compute::ActivationLayerInfo::ActivationFunction::RELU};
-
-  auto fn = make_layer<::arm_compute::CLActivationLayer>();
+  return _gen_map.at(type).initializer_gen;
+}
 
-  fn->configure(ifm_alloc, nullptr, act_info);
+std::shared_ptr<::internal::IStageGenerator> BackendResolver::getStageGenerator(const std::type_index &type)
+{
+  return _gen_map.at(type).stage_gen;
+}
 
-  _builder.append(std::move(fn));
+std::shared_ptr<::internal::ITensorBuilder> BackendResolver::getTensorBuilder(const std::type_index &type)
+{
+  return getStageGenerator(type)->tensor_builder();
 }
 
-void ActivationBuilder::append(FuseCode code, ::arm_compute::ICLTensor *ifm_alloc)
+std::set<std::shared_ptr<::internal::ITensorBuilder>> BackendResolver::getAllTensorBuilders()
 {
-  switch (code)
+  std::set<std::shared_ptr<::internal::ITensorBuilder>> ret;
+  for (const auto &it : _gen_map)
   {
-    case ANEURALNETWORKS_FUSED_NONE:
-    {
-      // DO NOTHING
-      break;
-    }
-    case ANEURALNETWORKS_FUSED_RELU:
-    {
-      appendReLU(ifm_alloc);
-      break;
-    }
-    default:
-    {
-      throw std::runtime_error("Not supported, yet");
-    }
+    ret.insert(it.second.stage_gen->tensor_builder());
   }
+  return ret;
 }
 
-#include <arm_compute/runtime/CL/functions/CLConvolutionLayer.h>
+#include "internal/arm_compute/InitializerGenerator.h"
+#include "internal/arm_compute/StageGenerator.h"
+//#include "internal/cpu/InitializerGenerator.h"
+//#include "internal/cpu/StageGenerator.h"
 
 class Planner : public ::internal::tflite::op::NodeVisitor
 {
 public:
-  Planner(const ::internal::tflite::operand::Set &ctx, IPlanBuilder &builder)
-      : _ctx{ctx}, _builder{builder}
+  Planner(const ::internal::tflite::operand::Set &ctx, IPlanBuilder &builder, BackendResolver &backend_resolver)
+      : _ctx{ctx}, _builder{builder}, _backend_resolver(backend_resolver)
   {
-    // DO NOTHING
   }
 
 public:
@@ -256,6 +147,7 @@ public:
 private:
   const ::internal::tflite::operand::Set &_ctx;
   IPlanBuilder &_builder;
+  BackendResolver &_backend_resolver;
 };
 
 void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node)
@@ -266,120 +158,25 @@ void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node)
   const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
   const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
 
-  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
-  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
-
-  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
-  const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
-
   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
   const auto ker_shape = _ctx.at(ker_index).shape().asKernel();
   const auto bias_size = _ctx.at(bias_index).shape().asVector();
 
-  const PaddingCode padding_type =
-      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
-
-  Stride stride;
-
-  stride.vertical = _ctx.at(vstride_index).asScalar<int32_t>();
-  stride.horizontal = _ctx.at(hstride_index).asScalar<int32_t>();
-
-  assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
-         (ANEURALNETWORKS_PADDING_VALID == padding_type));
-
   // Set Shape Constraints
   _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape));
   _builder.addShapeConstr(ifm_index, asTensorInfo(ifm_shape));
   _builder.addShapeConstr(ker_index, asTensorInfo(ker_shape));
   _builder.addShapeConstr(bias_index, asTensorInfo(bias_size));
 
-  // Set initializer for kernel
-  {
-    auto ker_base = _ctx.at(ker_index).data().base();
-    auto ker_size = _ctx.at(ker_index).data().size();
-
-    auto initializer = [ker_shape, ker_base, ker_size](::arm_compute::ITensor &tensor) {
-      const ::internal::nnapi::kernel::Reader<float> from{ker_shape, ker_base, ker_size};
-      ::internal::arm_compute::kernel::View<float> into{&tensor};
-
-      ::nnfw::util::kernel::iterate(ker_shape)
-          << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
-               const auto value = from.at(nth, ch, row, col);
-               into.at(nth, ch, row, col) = value;
-             };
-    };
-
-    _builder.addInitializer(ker_index, initializer);
-  }
-
-  // Set initializer for bias
-  {
-    auto bias_base = _ctx.at(bias_index).data().base();
-
-    auto initializer = [bias_base, bias_size](::arm_compute::ITensor &tensor) {
-      for (uint32_t n = 0; n < bias_size; ++n)
-      {
-        const ::arm_compute::Coordinates coordinate{n};
+  // Generate Initializers
+  auto init_gen = _backend_resolver.getInitializerGenerator(typeid(node));
+  _builder.addInitializer(ker_index, init_gen->generateWeight(node));
+  _builder.addInitializer(bias_index, init_gen->generateBias(node));
 
-        float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
-
-        const float *from = reinterpret_cast<const float *>(bias_base) + n;
-        const auto value = *from;
-
-        *into = value;
-      }
-    };
-
-    _builder.addInitializer(bias_index, initializer);
-  }
-
-  // Construct operation parameters
-  struct Param
-  {
-    int ofm_index;
-    int ifm_index;
-    int ker_index;
-    int bias_index;
-
-    Padding padding;
-    Stride stride;
-
-    FuseCode activation;
-  };
-
-  Param param;
-
-  param.ofm_index = ofm_index.asInt();
-  param.ifm_index = ifm_index.asInt();
-  param.ker_index = ker_index.asInt();
-  param.bias_index = bias_index.asInt();
-
-  param.stride = stride;
-  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
-                      ? same_padding(ifm_shape, ofm_shape, stride, ker_shape.W, ker_shape.H)
-                      : valid_padding();
-
-  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
-
-  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
-    auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
-    auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
-    auto ker_alloc = ctx.at(::internal::tflite::operand::Index{param.ker_index});
-    auto bias_alloc = ctx.at(::internal::tflite::operand::Index{param.bias_index});
-
-    const auto conv_info = asPadStringInfo(param.padding, param.stride);
-
-    std::unique_ptr<::arm_compute::CLConvolutionLayer> fn{new ::arm_compute::CLConvolutionLayer};
-
-    fn->configure(ifm_alloc, ker_alloc, bias_alloc, ofm_alloc, conv_info);
-
-    builder.append(std::move(fn));
-
-    ActivationBuilder{builder}.append(param.activation, ofm_alloc);
-  };
-
-  _builder.addStage(stage);
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
 void Planner::visit(const ::internal::tflite::op::MaxPool2D::implicit::Node &node)
@@ -387,92 +184,16 @@ void Planner::visit(const ::internal::tflite::op::MaxPool2D::implicit::Node &node)
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
   const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
 
-  const ::internal::tflite::operand::Index kh_index{node.param().kh_index};
-  const ::internal::tflite::operand::Index kw_index{node.param().kw_index};
-
-  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
-  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
-
-  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
-
   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
 
-  const int32_t kh = _ctx.at(kh_index).asScalar<int32_t>();
-  const int32_t kw = _ctx.at(kw_index).asScalar<int32_t>();
-
-  const int32_t vstride = _ctx.at(vstride_index).asScalar<int32_t>();
-  const int32_t hstride = _ctx.at(hstride_index).asScalar<int32_t>();
-
-  const PaddingCode padding_type =
-      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
-
-  assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
-         (ANEURALNETWORKS_PADDING_VALID == padding_type));
-
   // Set Shape Constraints
   _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape));
   _builder.addShapeConstr(ifm_index, asTensorInfo(ifm_shape));
 
-  // Construct operation parameters
-  struct Param
-  {
-    int ofm_index;
-    int ifm_index;
-
-    uint32_t kw;
-    uint32_t kh;
-
-    Padding padding;
-    Stride stride;
-
-    // TODO Add 'activation' field
-  };
-
-  Param param;
-
-  param.ofm_index = ofm_index.asInt();
-  param.ifm_index = ifm_index.asInt();
-
-  param.kh = kh;
-  param.kw = kw;
-
-  param.stride.vertical = vstride;
-  param.stride.horizontal = hstride;
-
-  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
-                      ? same_padding(ifm_shape, ofm_shape, param.stride, kw, kh)
-                      : valid_padding();
-
-  VERBOSE(MaxPool2D) << "IFM_H: " << ifm_shape.H << std::endl;
-  VERBOSE(MaxPool2D) << "IFM_W: " << ifm_shape.W << std::endl;
-  VERBOSE(MaxPool2D) << "OFM_H: " << ofm_shape.H << std::endl;
-  VERBOSE(MaxPool2D) << "OFM_W: " << ofm_shape.W << std::endl;
-  VERBOSE(MaxPool2D) << "KER_H: " << kh << std::endl;
-  VERBOSE(MaxPool2D) << "KER_W: " << kw << std::endl;
-  VERBOSE(MaxPool2D) << "STRIDE_H: " << vstride << std::endl;
-  VERBOSE(MaxPool2D) << "STRIDE_W: " << hstride << std::endl;
-  VERBOSE(MaxPool2D) << "PAD(T): " << param.padding.top << std::endl;
-  VERBOSE(MaxPool2D) << "PAD(B): " << param.padding.bottom << std::endl;
-  VERBOSE(MaxPool2D) << "PAD(L): " << param.padding.left << std::endl;
-  VERBOSE(MaxPool2D) << "PAD(R): " << param.padding.right << std::endl;
-
-  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
-    auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
-    auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
-
-    ::arm_compute::PoolingLayerInfo info{::arm_compute::PoolingType::MAX,
-                                         ::arm_compute::Size2D{param.kw, param.kh},
-                                         asPadStringInfo(param.padding, param.stride)};
-
-    std::unique_ptr<::arm_compute::CLPoolingLayer> fn{new ::arm_compute::CLPoolingLayer};
-
-    fn->configure(ifm_alloc, ofm_alloc, info);
-
-    builder.append(std::move(fn));
-  };
-
-  _builder.addStage(stage);
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
 void Planner::visit(const ::internal::tflite::op::AvgPool2D::implicit::Node &node)
@@ -480,93 +201,16 @@ void Planner::visit(const ::internal::tflite::op::AvgPool2D::implicit::Node &node)
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
   const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
 
-  const ::internal::tflite::operand::Index kh_index{node.param().kh_index};
-  const ::internal::tflite::operand::Index kw_index{node.param().kw_index};
-
-  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
-  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
-
-  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
-
   const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
   const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
 
-  const int32_t kh = _ctx.at(kh_index).asScalar<int32_t>();
-  const int32_t kw = _ctx.at(kw_index).asScalar<int32_t>();
-
-  const int32_t vstride = _ctx.at(vstride_index).asScalar<int32_t>();
-  const int32_t hstride = _ctx.at(hstride_index).asScalar<int32_t>();
-
-  const PaddingCode padding_type =
-      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
-
-  assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
-         (ANEURALNETWORKS_PADDING_VALID == padding_type));
-
   // Set Shape Constraints
   _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape));
   _builder.addShapeConstr(ifm_index, asTensorInfo(ifm_shape));
 
-  // Construct operation parameters
-  struct Param
-  {
-    int ofm_index;
-    int ifm_index;
-
-    uint32_t kw;
-    uint32_t kh;
-
-    Padding padding;
-    Stride stride;
-
-    // TODO Add 'activation' field
-  };
-
-  Param param;
-
-  param.ofm_index = ofm_index.asInt();
-  param.ifm_index = ifm_index.asInt();
-
-  param.kh = kh;
-  param.kw = kw;
-
-  param.stride.vertical = vstride;
-  param.stride.horizontal = hstride;
-
-  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
-                      ? same_padding(ifm_shape, ofm_shape, param.stride, kw, kh)
-                      : valid_padding();
-
-  VERBOSE(AvgPool2D) << "IFM_H: " << ifm_shape.H << std::endl;
-  VERBOSE(AvgPool2D) << "IFM_W: " << ifm_shape.W << std::endl;
-  VERBOSE(AvgPool2D) << "OFM_H: " << ofm_shape.H << std::endl;
-  VERBOSE(AvgPool2D) << "OFM_W: " << ofm_shape.W << std::endl;
-  VERBOSE(AvgPool2D) << "KER_H: " << kh << std::endl;
-  VERBOSE(AvgPool2D) << "KER_W: " << kw << std::endl;
-  VERBOSE(AvgPool2D) << "STRIDE_H: " << vstride << std::endl;
-  VERBOSE(AvgPool2D) << "STRIDE_W: " << hstride << std::endl;
-  VERBOSE(AvgPool2D) << "PAD: " << to_string(padding_type) << std::endl;
-  VERBOSE(AvgPool2D) << "PAD(T): " << param.padding.top << std::endl;
-  VERBOSE(AvgPool2D) << "PAD(B): " << param.padding.bottom << std::endl;
-  VERBOSE(AvgPool2D) << "PAD(L): " << param.padding.left << std::endl;
-  VERBOSE(AvgPool2D) << "PAD(R): " << param.padding.right << std::endl;
-
-  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
-    auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
-    auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
-
-    ::arm_compute::PoolingLayerInfo info{
-        ::arm_compute::PoolingType::AVG, ::arm_compute::Size2D{param.kw, param.kh},
-        asPadStringInfo(param.padding, param.stride), true /* exclude_padding */};
-
-    std::unique_ptr<::arm_compute::CLPoolingLayer> fn{new ::arm_compute::CLPoolingLayer};
-
-    fn->configure(ifm_alloc, ofm_alloc, info);
-
-    builder.append(std::move(fn));
-  };
-
-  _builder.addStage(stage);
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
 void Planner::visit(const ::internal::tflite::op::Concat::Node &node)
@@ -592,14 +236,12 @@ void Planner::visit(const ::internal::tflite::op::Concat::Node &node)
   {
     const ::internal::tflite::operand::Index ifm_index{index};
     const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
-
-    _builder.addSubsumptionConstr(ifm_index, ofm_index, ::arm_compute::Coordinates{0, 0, depth, 0},
-                                  asTensorShape(ifm_shape));
-
-    depth += ifm_shape.C;
+    _builder.addShapeConstr(ifm_index, asTensorInfo(ifm_shape));
   }
 
-  // NOTE Concat has no actual operation!
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
 void Planner::visit(const ::internal::tflite::op::FullyConnected::Node &node)
@@ -634,98 +276,14 @@ void Planner::visit(const ::internal::tflite::op::FullyConnected::Node &node)
   _builder.addShapeConstr(weight_index, asTensorInfo(num_output /*H*/, input_size /*W*/));
   _builder.addShapeConstr(bias_index, asTensorInfo(bias_size));
 
-  // Set initializer for weight
-  {
-    auto weight_base = _ctx.at(weight_index).data().base();
-    auto weight_size = _ctx.at(weight_index).data().size();
-
-    auto initializer = [num_output, ifm_shape, weight_base,
-                        weight_size](::arm_compute::ITensor &tensor) {
-      const ::nnfw::util::kernel::Shape ker_shape{num_output, ifm_shape.C, ifm_shape.H,
-                                                  ifm_shape.W};
-      const ::internal::nnapi::kernel::Reader<float> from{ker_shape, weight_base, weight_size};
-
-      ::nnfw::util::kernel::iterate(ker_shape)
-          << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
-               const auto value = from.at(nth, ch, row, col);
-
-               uint32_t offset = 0;
-
-               // ARM Compute Library uses 'NCHW' ordering
-               offset += nth * ifm_shape.C * ifm_shape.H * ifm_shape.W;
-               offset += ch * ifm_shape.H * ifm_shape.W;
-               offset += row * ifm_shape.W;
-               offset += col;
-
-               const ::arm_compute::Coordinates coordinate{offset};
-
-               auto into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
-
-               *into = value;
-             };
-    };
-
-    _builder.addInitializer(weight_index, initializer);
-  }
-
-  // Set initializer for bias
-  {
-    auto bias_base = _ctx.at(bias_index).data().base();
-
-    auto initializer = [bias_base, bias_size](::arm_compute::ITensor &tensor) {
-      for (uint32_t n = 0; n < bias_size; ++n)
-      {
-        const ::arm_compute::Coordinates coordinate{n};
-
-        float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
-
-        const float *from = reinterpret_cast<const float *>(bias_base) + n;
-        const auto value = *from;
-
-        *into = value;
-      }
-    };
-
-    _builder.addInitializer(bias_index, initializer);
-  }
-
-  // Construct operation parameters
-  struct Param
-  {
-    int output_index;
-
-    int input_index;
-    int weight_index;
-    int bias_index;
-
-    FuseCode activation;
-  };
-
-  Param param;
+  // Generate Initializers
+  auto init_gen = _backend_resolver.getInitializerGenerator(typeid(node));
+  _builder.addInitializer(weight_index, init_gen->generateWeight(node));
+  _builder.addInitializer(bias_index, init_gen->generateBias(node));
 
-  param.output_index = output_index.asInt();
-  param.input_index = input_index.asInt();
-  param.weight_index = weight_index.asInt();
-  param.bias_index = bias_index.asInt();
-
-  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
-
-  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
-    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
-    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
-    auto weight_alloc = ctx.at(::internal::tflite::operand::Index{param.weight_index});
-    auto bias_alloc = ctx.at(::internal::tflite::operand::Index{param.bias_index});
-
-    auto fn = make_layer<::arm_compute::CLFullyConnectedLayer>();
-
-    fn->configure(input_alloc, weight_alloc, bias_alloc, output_alloc);
-
-    builder.append(std::move(fn));
-
-    ActivationBuilder{builder}.append(param.activation, output_alloc);
-  };
-
-  _builder.addStage(stage);
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
 void Planner::visit(const ::internal::tflite::op::Reshape::Node &node)
@@ -753,29 +311,9 @@ void Planner::visit(const ::internal::tflite::op::Reshape::Node &node)
   _builder.addShapeConstr(output_index, asTensorInfo(out_size));
   _builder.addShapeConstr(input_index, asTensorInfo(ifm_shape));
 
-  struct Param
-  {
-    int output_index;
-    int input_index;
-  };
-
-  Param param;
-
-  param.output_index = output_index.asInt();
-  param.input_index = input_index.asInt();
-
-  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
-    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
-    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
-
-    auto fn = make_layer<::arm_compute::CLReshapeLayer>();
-
-    fn->configure(input_alloc, output_alloc);
-
-    builder.append(std::move(fn));
-  };
-
-  _builder.addStage(stage);
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
 void Planner::visit(const ::internal::tflite::op::Softmax::Node &node)
@@ -798,52 +336,94 @@ void Planner::visit(const ::internal::tflite::op::Softmax::Node &node)
   _builder.addShapeConstr(output_index, asTensorInfo(len));
   _builder.addShapeConstr(input_index, asTensorInfo(len));
 
-  struct Param
-  {
-    int output_index;
-    int input_index;
-    float scale;
-  };
-
-  Param param;
-
-  param.output_index = output_index.asInt();
-  param.input_index = input_index.asInt();
-  // TODO Set scale correctly
-  param.scale = 1.0f;
-
-  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
-    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
-    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
-
-    auto fn = make_layer<::arm_compute::CLSoftmaxLayer>();
-
-    fn->configure(input_alloc, output_alloc, param.scale);
-
-    builder.append(std::move(fn));
-  };
-
-  _builder.addStage(stage);
+  // Generate Stage
+  auto stage_gen = _backend_resolver.getStageGenerator(typeid(node));
+  _builder.addStage(stage_gen->generate(node));
 }
 
-class AllocationContext final : public IAllocationContext
+class TensorMarker : public ::internal::tflite::op::NodeVisitor
 {
 public:
-  AllocationContext(::internal::arm_compute::Plan &plan) : _plan{plan}
+  TensorMarker(::internal::ITensorBuilder& tensor_builder)
+      : _tensor_builder{tensor_builder}
   {
     // DO NOTHING
   }
 
 public:
-  ::arm_compute::ICLTensor *at(const ::internal::tflite::operand::Index &ind) const override
+  void visit(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  void visit(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) override;
+  void visit(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) override;
+  void visit(const ::internal::tflite::op::Concat::Node &node) override;
+  void visit(const ::internal::tflite::op::FullyConnected::Node &node) override;
+  void visit(const ::internal::tflite::op::Reshape::Node &node) override;
+  void visit(const ::internal::tflite::op::Softmax::Node &node) override;
+
+private:
+  void mark(int32_t ind)
   {
-    return _plan.operands().at(ind).ptr();
+    _tensor_builder.mark(::internal::tflite::operand::Index{ind});
   }
 
 private:
-  ::internal::arm_compute::Plan &_plan;
+  ::internal::ITensorBuilder &_tensor_builder;
 };
 
+void TensorMarker::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  const auto& param = node.param();
+  mark(param.ofm_index);
+  mark(param.ifm_index);
+  mark(param.ker_index);
+  mark(param.bias_index);
+}
+
+void TensorMarker::visit(const ::internal::tflite::op::MaxPool2D::implicit::Node &node)
+{
+  const auto& param = node.param();
+  mark(param.ofm_index);
+  mark(param.ifm_index);
+}
+
+void TensorMarker::visit(const ::internal::tflite::op::AvgPool2D::implicit::Node &node)
+{
+  const auto& param = node.param();
+  mark(param.ofm_index);
+  mark(param.ifm_index);
+}
+
+void TensorMarker::visit(const ::internal::tflite::op::Concat::Node &node)
+{
+  const auto& param = node.param();
+  for (auto ind : param.ifm_indexes)
+  {
+    mark(ind);
+  }
+}
+
+void TensorMarker::visit(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  const auto& param = node.param();
+  mark(param.output_index);
+  mark(param.input_index);
+  mark(param.weight_index);
+  mark(param.bias_index);
+}
+
+void TensorMarker::visit(const ::internal::tflite::op::Reshape::Node &node)
+{
+  const auto& param = node.param();
+  mark(param.output_index);
+  mark(param.input_index);
+}
+
+void TensorMarker::visit(const ::internal::tflite::op::Softmax::Node &node)
+{
+  const auto& param = node.param();
+  mark(param.output_index);
+  mark(param.input_index);
+}
+
 class ExecutionBuilder final : public IExecutionBuilder
 {
 public:
@@ -875,12 +455,6 @@ public:
                       const ::arm_compute::TensorInfo &info) override;
 
 public:
-  void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
-                            const ::internal::tflite::operand::Index &base,
-                            const ::arm_compute::Coordinates &offset,
-                            const ::arm_compute::TensorShape &shape) override;
-
-public:
   void addInitializer(const ::internal::tflite::operand::Index &ind,
                       const Initializer &initializer) override;
 
@@ -888,36 +462,19 @@ public:
   void addStage(const Stage &stage) override;
 
 public:
-  void finalize(void) const;
+  void finalize(const std::set<std::shared_ptr<::internal::ITensorBuilder>> &tensor_builders);
 
-private:
-  ::internal::arm_compute::Plan &_plan;
+public:
+  const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx()
+  {
+    return _tensor_info_ctx;
+  }
 
 private:
-  struct Subsumption
-  {
-  public:
-    Subsumption(const ::internal::tflite::operand::Index &base,
-                const ::arm_compute::Coordinates &offset, const ::arm_compute::TensorShape &shape)
-        : _base{base}, _offset{offset}, _shape{shape}
-    {
-      // DO NOTHING
-    }
-
-  public:
-    const ::internal::tflite::operand::Index &base(void) const { return _base; }
-    const ::arm_compute::Coordinates &offset(void) const { return _offset; }
-    const ::arm_compute::TensorShape &shape(void) const { return _shape; }
-
-  private:
-    const ::internal::tflite::operand::Index _base;
-    const ::arm_compute::Coordinates _offset;
-    const ::arm_compute::TensorShape _shape;
-  };
+  ::internal::arm_compute::Plan &_plan;
 
 private:
   std::map<int, ::arm_compute::TensorInfo> _tensor_info_ctx;
-  std::map<int, std::shared_ptr<Subsumption>> _subsumption_ctx;
   std::map<int, Initializer> _initializer_ctx;
   std::vector<Stage> _stages;
 };
@@ -928,14 +485,6 @@ void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &ind,
   _tensor_info_ctx[ind.asInt()] = info;
 }
 
-void PlanBuilder::addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
-                                       const ::internal::tflite::operand::Index &base,
-                                       const ::arm_compute::Coordinates &offset,
-                                       const ::arm_compute::TensorShape &shape)
-{
-  _subsumption_ctx[ind.asInt()] = std::make_shared<Subsumption>(base, offset, shape);
-}
-
 void PlanBuilder::addInitializer(const ::internal::tflite::operand::Index &ind,
                                  const Initializer &initializer)
 {
@@ -944,107 +493,30 @@ void PlanBuilder::addInitializer(const ::internal::tflite::operand::Index &ind,
 
 void PlanBuilder::addStage(const Stage &stage) { _stages.emplace_back(stage); }
 
-#include <stack>
-
-void PlanBuilder::finalize(void) const
+void PlanBuilder::finalize(const std::set<std::shared_ptr<::internal::ITensorBuilder>> &tensor_builders)
 {
-  // CLTensor objects to be initialized later
-  std::vector<std::shared_ptr<::arm_compute::CLTensor>> tensors;
-
-  // Create CLTensor & CLSubTensor
-  auto isAllocated = [this](int ind) {
-    const ::internal::tflite::operand::Index operand_index{ind};
-    return _plan.operands().exist(operand_index);
-  };
-
-  auto setCLTensor = [&](int ind) {
-    auto tensor = std::make_shared<::arm_compute::CLTensor>();
-
-    tensor->allocator()->init(_tensor_info_ctx.at(ind));
-
-    // NOTE Do NOT allocate here. allocate should be invoked after configure functions
-    _plan.operands().set(::internal::tflite::operand::Index{ind}, tensor);
-    tensors.emplace_back(tensor);
-  };
-
-  auto setCLSubTensor = [&](int curr) {
-    const auto &sub_info = *(_subsumption_ctx.find(curr)->second);
-
-    auto base_tensor = _plan.operands().at(sub_info.base()).ptr();
+  // Mark tensors
+  const auto &operations = _plan.model().operations();
 
-    assert(base_tensor != nullptr);
-
-    auto curr_tensor = std::make_shared<::arm_compute::CLSubTensor>(base_tensor, sub_info.shape(),
-                                                                    sub_info.offset());
-
-    _plan.operands().set(::internal::tflite::operand::Index{curr}, curr_tensor);
-  };
-
-  for (auto it = _subsumption_ctx.begin(); it != _subsumption_ctx.end(); ++it)
-  {
-    std::stack<int> stack;
-
-    stack.push(it->first);
-
-    while (!stack.empty())
-    {
-      const auto curr = stack.top();
-
-      if (isAllocated(curr))
-      {
-        // Skip if already allocated
-        stack.pop();
-        continue;
-      }
-
-      auto it_s = _subsumption_ctx.find(curr);
-
-      if (it_s == _subsumption_ctx.end())
-      {
-        setCLTensor(curr);
-        stack.pop();
-        continue;
-      }
-
-      const auto &sub_info = *(it_s->second);
-
-      if (isAllocated(sub_info.base().asInt()))
-      {
-        setCLSubTensor(curr);
-        stack.pop();
-      }
-      else
-      {
-        // Allocate base tensor first
-        stack.push(sub_info.base().asInt());
-      }
-    }
-  }
-
-  for (auto it = _tensor_info_ctx.begin(); it != _tensor_info_ctx.end(); ++it)
+  // Prepare tensors
+  for (auto &tensor_builder : tensor_builders)
   {
-    if (isAllocated(it->first))
-    {
-      // Skip if already allocated
-      continue;
-    }
-
-    setCLTensor(it->first);
+    tensor_builder->prepare(_tensor_info_ctx);
   }
 
   // Process Stage
-  AllocationContext allocation_context{_plan};
   ExecutionBuilder execution_builder{_plan};
 
   for (const auto &stage : _stages)
   {
-    stage(allocation_context, execution_builder);
+    stage(execution_builder);
   }
 
-  // Allocate Tensor Memory
-  for (const auto &tensor : tensors)
+  // TODO Add code for CPU/ACL tensor allocation
+  // Allocate Tensor Memory for cl_tensors
+  for (auto &tensor_builder : tensor_builders)
   {
-    tensor->allocator()->allocate();
+    tensor_builder->allocate();
   }
 
   // Fill weight/bias
@@ -1074,17 +546,37 @@ int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation *compilation)
 {
   arm_compute::CLScheduler::get().default_init();
 
-  const auto &operands = compilation->plan().model().operands();
-  const auto &operations = compilation->plan().model().operations();
+  auto &plan = compilation->plan();
+  const auto &operands = plan.model().operands();
+  const auto &operations = plan.model().operations();
 
-  PlanBuilder plan_builder{compilation->plan()};
+  ::internal::BackendManager backend_manager{plan};
+  BackendResolver backend_resolver{backend_manager};
+  PlanBuilder plan_builder{plan};
 
   for (uint32_t n = 0; n < operations.size(); ++n)
   {
-    operations.at(n).accept(Planner{operands, plan_builder});
+    operations.at(n).accept(Planner{operands, plan_builder, backend_resolver});
+  }
+
+  // TODO Add optimization passes
+
+  for (uint32_t n = 0; n < operations.size(); ++n)
+  {
+    const auto& op = operations.at(n);
+    auto tensor_builder = backend_resolver.getTensorBuilder(typeid(op));
+    op.accept(TensorMarker{*tensor_builder});
+  }
+
+  /*
+  for (auto it : plan_builder.tensor_info_ctx())
+  {
+    auto tensor_builder = backend_manager.getTensorBuilder("arm_compute");
+    tensor_builder->mark(::internal::tflite::operand::Index{it.first});
   }
+  */
 
-  plan_builder.finalize();
+  plan_builder.finalize(backend_resolver.getAllTensorBuilders());
 
   return ANEURALNETWORKS_NO_ERROR;
 }
diff --git a/runtimes/new_runtime/src/internal/BackendManager.cc b/runtimes/new_runtime/src/internal/BackendManager.cc
new file mode 100644 (file)
index 0000000..ca11ad1
--- /dev/null
@@ -0,0 +1,31 @@
+#include "internal/BackendManager.h"
+
+#include "internal/arm_compute/TensorBuilder.h"
+#include "internal/arm_compute/InitializerGenerator.h"
+#include "internal/arm_compute/StageGenerator.h"
+
+namespace internal
+{
+
+BackendManager::BackendManager(::internal::arm_compute::Plan& plan) : _plan(plan)
+{
+  const auto &operands = _plan.model().operands();
+
+  // Add arm_compute backend
+  {
+    auto acl_tensor_builder = std::make_shared<::internal::arm_compute::TensorBuilder>(_plan);
+    auto acl_initializer_gen = std::make_shared<::internal::arm_compute::InitializerGenerator>(operands);
+    auto acl_stage_gen = std::make_shared<::internal::arm_compute::StageGenerator>(operands, acl_tensor_builder);
+
+    _gen_map["arm_compute"] = {acl_initializer_gen, acl_stage_gen};
+  }
+
+  // TODO Add CPU backend
+}
+
+Backend BackendManager::get(const std::string &key)
+{
+  return _gen_map.at(key);
+}
+
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/BackendManager.h b/runtimes/new_runtime/src/internal/BackendManager.h
new file mode 100644 (file)
index 0000000..eabbe8d
--- /dev/null
@@ -0,0 +1,47 @@
+#ifndef __INTERNAL_BACKEND_MANAGER_H__
+#define __INTERNAL_BACKEND_MANAGER_H__
+
+#include <memory>
+
+#include "internal/arm_compute.h"
+#include "internal/IInitializerGenerator.h"
+#include "internal/IStageGenerator.h"
+#include "internal/ITensorBuilder.h"
+
+namespace internal
+{
+
+struct Backend
+{
+  std::shared_ptr<::internal::IInitializerGenerator> initializer_gen;
+  std::shared_ptr<::internal::IStageGenerator> stage_gen;
+
+  Backend(const std::shared_ptr<::internal::IInitializerGenerator> &initializer_gen,
+          const std::shared_ptr<::internal::IStageGenerator> &stage_gen)
+    : initializer_gen(initializer_gen), stage_gen(stage_gen)
+  {
+    // DO NOTHING
+  }
+
+  Backend(void)
+    : initializer_gen(nullptr), stage_gen(nullptr)
+  {
+    // DO NOTHING
+  }
+};
+
+class BackendManager
+{
+public:
+  BackendManager(::internal::arm_compute::Plan& plan);
+
+  Backend get(const std::string &key);
+
+private:
+  ::internal::arm_compute::Plan &_plan;
+  std::map<std::string, Backend> _gen_map;
+};
+
+} // namespace internal
+
+#endif // __INTERNAL_BACKEND_MANAGER_H__
diff --git a/runtimes/new_runtime/src/internal/IInitializerGenerator.h b/runtimes/new_runtime/src/internal/IInitializerGenerator.h
new file mode 100644 (file)
index 0000000..a09faf5
--- /dev/null
@@ -0,0 +1,26 @@
+#ifndef __INTERNAL_IINITIALIZER_GENERATOR_H__
+#define __INTERNAL_IINITIALIZER_GENERATOR_H__
+
+#include "arm_compute/core/ITensor.h"
+
+#include "internal/op/Conv2D.h"
+#include "internal/op/FullyConnected.h"
+
+using Initializer = std::function<void(::arm_compute::ITensor &)>;
+
+namespace internal
+{
+
+struct IInitializerGenerator {
+  virtual ~IInitializerGenerator() = default;
+
+  virtual Initializer generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node) = 0;
+  virtual Initializer generateWeight(const ::internal::tflite::op::FullyConnected::Node &node) = 0;
+
+  virtual Initializer generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node) = 0;
+  virtual Initializer generateBias(const ::internal::tflite::op::FullyConnected::Node &node) = 0;
+};
+
+} // namespace internal
+
+#endif // __INTERNAL_IINITIALIZER_GENERATOR_H__
diff --git a/runtimes/new_runtime/src/internal/IObject.h b/runtimes/new_runtime/src/internal/IObject.h
new file mode 100644 (file)
index 0000000..1ea6615
--- /dev/null
@@ -0,0 +1,20 @@
+#ifndef __INTERNAL_IOBJECT_H__
+#define __INTERNAL_IOBJECT_H__
+
+#include <functional>
+
+#include <arm_compute/core/ITensor.h>
+
+namespace internal
+{
+
+struct IObject
+{
+  virtual ~IObject() = default;
+  virtual ::arm_compute::ITensor *ptr(void) const = 0;
+  virtual void access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const = 0;
+};
+
+} // namespace internal
+
+#endif // __INTERNAL_IOBJECT_H__
diff --git a/runtimes/new_runtime/src/internal/IStageGenerator.h b/runtimes/new_runtime/src/internal/IStageGenerator.h
new file mode 100644 (file)
index 0000000..9932e46
--- /dev/null
@@ -0,0 +1,47 @@
+#ifndef __INTERNAL_ISTAGE_GENERATOR_H__
+#define __INTERNAL_ISTAGE_GENERATOR_H__
+
+#include <memory>
+#include <functional>
+
+#include <arm_compute/runtime/IFunction.h>
+
+#include "internal/ITensorBuilder.h"
+#include "internal/op/Conv2D.h"
+#include "internal/op/MaxPool2D.h"
+#include "internal/op/AvgPool2D.h"
+#include "internal/op/Concat.h"
+#include "internal/op/FullyConnected.h"
+#include "internal/op/Reshape.h"
+#include "internal/op/Softmax.h"
+
+struct IExecutionBuilder
+{
+  virtual ~IExecutionBuilder() = default;
+
+  virtual void append(std::unique_ptr<::arm_compute::IFunction> &&f) = 0;
+};
+
+using Stage = std::function<void(IExecutionBuilder &)>;
+
+namespace internal
+{
+
+struct IStageGenerator
+{
+  virtual ~IStageGenerator() = default;
+
+  virtual std::shared_ptr<ITensorBuilder> tensor_builder() = 0;
+
+  virtual Stage generate(const ::internal::tflite::op::Conv2D::implicit::Node &node) = 0;
+  virtual Stage generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) = 0;
+  virtual Stage generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) = 0;
+  virtual Stage generate(const ::internal::tflite::op::Concat::Node &node) = 0;
+  virtual Stage generate(const ::internal::tflite::op::FullyConnected::Node &node) = 0;
+  virtual Stage generate(const ::internal::tflite::op::Reshape::Node &node) = 0;
+  virtual Stage generate(const ::internal::tflite::op::Softmax::Node &node) = 0;
+};
+
+} // namespace internal
+
+#endif // __INTERNAL_ISTAGE_GENERATOR_H__
diff --git a/runtimes/new_runtime/src/internal/ITensorBuilder.h b/runtimes/new_runtime/src/internal/ITensorBuilder.h
new file mode 100644 (file)
index 0000000..2f6f5e4
--- /dev/null
@@ -0,0 +1,23 @@
+#ifndef __INTERNAL_ITENSOR_BUILDER_H__
+#define __INTERNAL_ITENSOR_BUILDER_H__
+
+#include <map>
+#include <arm_compute/core/TensorInfo.h>
+
+#include "internal/Model.h"
+
+namespace internal
+{
+
+struct ITensorBuilder
+{
+  virtual ~ITensorBuilder(void) = default;
+  virtual void mark(const ::internal::tflite::operand::Index& ind) = 0;
+  // TODO Add an interface for adding subsumption info
+  virtual void prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx) = 0;
+  virtual void allocate(void) = 0;
+};
+
+} // namespace internal
+
+#endif // __INTERNAL_ITENSOR_BUILDER_H__
diff --git a/runtimes/new_runtime/src/internal/Padding.cc b/runtimes/new_runtime/src/internal/Padding.cc
new file mode 100644 (file)
index 0000000..21ce014
--- /dev/null
@@ -0,0 +1,56 @@
+#include "internal/Padding.h"
+
+#include <algorithm>
+
+namespace internal
+{
+
+Padding valid_padding(void)
+{
+  //
+  // ANEURALNETWORKS_PADDING_VALID
+  //
+  // VALID padding. No padding.
+  //
+  // When the input size is not evenly divisible by the filter size,
+  // the input at the end that could not fill the whole filter tile
+  // will simply be ignored.
+  //
+  Padding padding;
+
+  padding.top = 0;
+  padding.bottom = 0;
+  padding.left = 0;
+  padding.right = 0;
+
+  return padding;
+}
+
+Padding same_padding(const nnfw::util::feature::Shape &ifm_shape,
+                     const nnfw::util::feature::Shape &ofm_shape, const Stride &stride, uint32_t kw,
+                     uint32_t kh)
+{
+  Padding padding;
+
+  // ANEURALNETWORKS_PADDING_SAME (from NNAPI spec)
+  //
+  // SAME padding. Padding on both ends are the "same":
+  //
+  //   padding_to_beginning = total_padding / 2
+  //  padding_to_end = (total_padding + 1)/2.
+  //
+  const int32_t vertical_needed_input = (ofm_shape.H - 1) * stride.vertical + kh;
+  const int32_t vertical_total_padding = std::max(0, vertical_needed_input - ifm_shape.H);
+
+  const int32_t horizontal_needed_input = (ofm_shape.W - 1) * stride.horizontal + kw;
+  const int32_t horizontal_total_padding = std::max(0, horizontal_needed_input - ifm_shape.W);
+
+  padding.top = vertical_total_padding / 2;
+  padding.bottom = (vertical_total_padding + 1) / 2;
+  padding.left = horizontal_total_padding / 2;
+  padding.right = (horizontal_total_padding + 1) / 2;
+
+  return padding;
+}
+
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/Padding.h b/runtimes/new_runtime/src/internal/Padding.h
new file mode 100644 (file)
index 0000000..d1b413b
--- /dev/null
@@ -0,0 +1,32 @@
+#ifndef __INTERNAL_PADDING_H__
+#define __INTERNAL_PADDING_H__
+
+#include <stdint.h>
+
+#include <util/feature/Shape.h>
+
+namespace internal
+{
+
+struct Padding
+{
+  uint32_t top;
+  uint32_t bottom;
+  uint32_t left;
+  uint32_t right;
+};
+
+struct Stride
+{
+  uint32_t vertical;
+  uint32_t horizontal;
+};
+
+Padding valid_padding(void);
+Padding same_padding(const nnfw::util::feature::Shape &ifm_shape,
+                     const nnfw::util::feature::Shape &ofm_shape, const Stride &stride, uint32_t kw,
+                     uint32_t kh);
+
+} // namespace internal
+
+#endif // __INTERNAL_PADDING_H__
diff --git a/runtimes/new_runtime/src/internal/arm_compute.cc b/runtimes/new_runtime/src/internal/arm_compute.cc
index 394a64c..1e0097b 100644 (file)
@@ -32,11 +32,11 @@ namespace operand
 {
 
 Context &Context::set(const ::internal::tflite::operand::Index &id,
-                      const std::shared_ptr<::arm_compute::ICLTensor> &tensor)
+                      const std::shared_ptr<::internal::IObject> &object)
 {
   assert(_objects.find(id.asInt()) == _objects.end());
 
-  _objects[id.asInt()] = Object{tensor};
+  _objects[id.asInt()] = object;
   return (*this);
 }
 
diff --git a/runtimes/new_runtime/src/internal/arm_compute.h b/runtimes/new_runtime/src/internal/arm_compute.h
index da44d12..3a1e333 100644 (file)
@@ -3,6 +3,8 @@
 
 #include <arm_compute/core/CL/ICLTensor.h>
 
+#include "internal/IObject.h"
+
 namespace internal
 {
 namespace arm_compute
@@ -10,7 +12,7 @@ namespace arm_compute
 namespace operand
 {
 
-class Object
+class Object : public ::internal::IObject
 {
 public:
   Object() = default;
@@ -22,13 +24,13 @@ public:
   }
 
 public:
-  ::arm_compute::ICLTensor *ptr(void) const { return _tensor.get(); }
+  ::arm_compute::ICLTensor *ptr(void) const override { return _tensor.get(); }
 
 private:
   std::shared_ptr<::arm_compute::ICLTensor> _tensor;
 
 public:
-  void access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const;
+  void access(const std::function<void(::arm_compute::ITensor &tensor)> &fn) const override;
 };
 
 } // namespace operand
@@ -50,7 +52,7 @@ class Context
 {
 public:
   Context &set(const ::internal::tflite::operand::Index &ind,
-               const std::shared_ptr<::arm_compute::ICLTensor> &tensor);
+               const std::shared_ptr<::internal::IObject> &object);
 
 public:
   bool exist(const ::internal::tflite::operand::Index &ind) const
@@ -59,15 +61,15 @@ public:
   }
 
 public:
-  const Object &at(const ::internal::tflite::operand::Index &ind) const
+  const ::internal::IObject &at(const ::internal::tflite::operand::Index &ind) const
   {
-    return _objects.at(ind.asInt());
+    return *_objects.at(ind.asInt());
   }
 
-  Object &at(const ::internal::tflite::operand::Index &ind) { return _objects.at(ind.asInt()); }
+  ::internal::IObject &at(const ::internal::tflite::operand::Index &ind) { return *_objects.at(ind.asInt()); }
 
 private:
-  std::map<int, Object> _objects;
+  std::map<int, std::shared_ptr<IObject>> _objects;
 };
 
 } // namespace operand
@@ -111,6 +113,7 @@ namespace internal
 namespace arm_compute
 {
 
+// TODO class Plan should not be in `::internal::arm_compute`, we should put it in `::internal`
 class Plan
 {
 public:
diff --git a/runtimes/new_runtime/src/internal/arm_compute/InitializerGenerator.cc b/runtimes/new_runtime/src/internal/arm_compute/InitializerGenerator.cc
new file mode 100644 (file)
index 0000000..a0ea8c5
--- /dev/null
@@ -0,0 +1,126 @@
+#include "internal/arm_compute/InitializerGenerator.h"
+
+#include <arm_compute/core/Coordinates.h>
+
+#include "internal/arm_compute/kernel/View.h"
+#include "internal/nnapi/kernel/Reader.h"
+#include "util/kernel/IndexIterator.h"
+
+namespace internal
+{
+namespace arm_compute
+{
+
+InitializerGenerator::InitializerGenerator(const ::internal::tflite::operand::Set &ctx) : _ctx(ctx)
+{
+  // DO NOTHING
+}
+
+Initializer InitializerGenerator::generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
+
+  const auto ker_shape = _ctx.at(ker_index).shape().asKernel();
+  auto ker_base = _ctx.at(ker_index).data().base();
+  auto ker_size = _ctx.at(ker_index).data().size();
+
+  return [ker_shape, ker_base, ker_size](::arm_compute::ITensor &tensor) {
+    const ::internal::nnapi::kernel::Reader<float> from{ker_shape, ker_base, ker_size};
+    ::internal::arm_compute::kernel::View<float> into{&tensor};
+
+    ::nnfw::util::kernel::iterate(ker_shape)
+        << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
+             const auto value = from.at(nth, ch, row, col);
+             into.at(nth, ch, row, col) = value;
+           };
+  };
+}
+
+Initializer InitializerGenerator::generateWeight(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  const ::internal::tflite::operand::Index weight_index{node.param().weight_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  const auto num_output = _ctx.at(weight_index).shape().dim(0);
+  auto weight_base = _ctx.at(weight_index).data().base();
+  auto weight_size = _ctx.at(weight_index).data().size();
+
+  // NOTE We assume that input is a feature map
+  // TODO Remove this restriction!
+  const auto ifm_shape = _ctx.at(input_index).shape().asFeature();
+
+  return [num_output, ifm_shape, weight_base,
+                      weight_size](::arm_compute::ITensor &tensor) {
+    const ::nnfw::util::kernel::Shape ker_shape{num_output, ifm_shape.C, ifm_shape.H,
+                                                ifm_shape.W};
+    const ::internal::nnapi::kernel::Reader<float> from{ker_shape, weight_base, weight_size};
+
+    ::nnfw::util::kernel::iterate(ker_shape)
+        << [&](uint32_t nth, uint32_t ch, uint32_t row, uint32_t col) {
+             const auto value = from.at(nth, ch, row, col);
+
+             uint32_t offset = 0;
+
+             // ARM Compute Library uses 'NCHW' ordering
+             offset += nth * ifm_shape.C * ifm_shape.H * ifm_shape.W;
+             offset += ch * ifm_shape.H * ifm_shape.W;
+             offset += row * ifm_shape.W;
+             offset += col;
+
+             const ::arm_compute::Coordinates coordinate{offset};
+
+             auto into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+             *into = value;
+           };
+  };
+}
+
+Initializer InitializerGenerator::generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  // TODO Refactor so we can reuse the common code
+
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+  auto bias_base = _ctx.at(bias_index).data().base();
+  const auto bias_size = _ctx.at(bias_index).shape().asVector();
+
+  return [bias_base, bias_size](::arm_compute::ITensor &tensor) {
+    for (uint32_t n = 0; n < bias_size; ++n)
+    {
+      const ::arm_compute::Coordinates coordinate{n};
+
+      float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+      const float *from = reinterpret_cast<const float *>(bias_base) + n;
+      const auto value = *from;
+
+      *into = value;
+    }
+  };
+}
+
+Initializer InitializerGenerator::generateBias(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+  auto bias_base = _ctx.at(bias_index).data().base();
+  const auto bias_size = _ctx.at(bias_index).shape().asVector();
+
+  return [bias_base, bias_size](::arm_compute::ITensor &tensor) {
+    for (uint32_t n = 0; n < bias_size; ++n)
+    {
+      const ::arm_compute::Coordinates coordinate{n};
+
+      float *into = reinterpret_cast<float *>(tensor.ptr_to_element(coordinate));
+
+      const float *from = reinterpret_cast<const float *>(bias_base) + n;
+      const auto value = *from;
+
+      *into = value;
+    }
+  };
+}
+
+} // namespace arm_compute
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/arm_compute/InitializerGenerator.h b/runtimes/new_runtime/src/internal/arm_compute/InitializerGenerator.h
new file mode 100644 (file)
index 0000000..3bf610d
--- /dev/null
@@ -0,0 +1,31 @@
+#ifndef __INTERNAL_ARM_COMPUTE_INITIALIZER_GENERATOR_H__
+#define __INTERNAL_ARM_COMPUTE_INITIALIZER_GENERATOR_H__
+
+#include "internal/IInitializerGenerator.h"
+
+#include "internal/Model.h"
+
+namespace internal
+{
+namespace arm_compute
+{
+
+class InitializerGenerator : public ::internal::IInitializerGenerator
+{
+public:
+  InitializerGenerator(const ::internal::tflite::operand::Set &ctx);
+
+  Initializer generateWeight(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  Initializer generateWeight(const ::internal::tflite::op::FullyConnected::Node &node) override;
+
+  Initializer generateBias(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  Initializer generateBias(const ::internal::tflite::op::FullyConnected::Node &node) override;
+
+private:
+  const ::internal::tflite::operand::Set &_ctx;
+};
+
+} // namespace arm_compute
+} // namespace internal
+
+#endif // __INTERNAL_ARM_COMPUTE_INITIALIZER_GENERATOR_H__
diff --git a/runtimes/new_runtime/src/internal/arm_compute/StageGenerator.cc b/runtimes/new_runtime/src/internal/arm_compute/StageGenerator.cc
new file mode 100644 (file)
index 0000000..dc878b0
--- /dev/null
@@ -0,0 +1,488 @@
+#include "internal/arm_compute/StageGenerator.h"
+
+#include <arm_compute/runtime/CL/functions/CLConvolutionLayer.h>
+#include <arm_compute/runtime/CL/functions/CLPoolingLayer.h>
+#include <arm_compute/runtime/CL/functions/CLActivationLayer.h>
+#include <arm_compute/runtime/CL/functions/CLReshapeLayer.h>
+#include <arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h>
+#include <arm_compute/runtime/CL/functions/CLSoftmaxLayer.h>
+
+#include "internal/Padding.h"
+#include "internal/Model.h"
+
+#include "logging.h"
+
+#include <NeuralNetworks.h>
+
+const char *to_string(const PaddingCode &code)
+{
+  assert((ANEURALNETWORKS_PADDING_SAME == code) || (ANEURALNETWORKS_PADDING_VALID == code));
+
+  switch (code)
+  {
+    case ANEURALNETWORKS_PADDING_SAME:
+      return "ANEURALNETWORKS_PADDING_SAME";
+    case ANEURALNETWORKS_PADDING_VALID:
+      return "ANEURALNETWORKS_PADDING_VALID";
+  }
+
+  return nullptr;
+}
+
+template <typename T> std::unique_ptr<T> make_layer(void) { return std::unique_ptr<T>{new T}; }
+
+::arm_compute::PadStrideInfo asPadStrideInfo(const ::internal::Padding &padding, const ::internal::Stride &stride)
+{
+  return ::arm_compute::PadStrideInfo{stride.horizontal,
+                                      stride.vertical,
+                                      padding.left,
+                                      padding.right,
+                                      padding.top,
+                                      padding.bottom,
+                                      ::arm_compute::DimensionRoundingType::FLOOR};
+}
+
+namespace internal
+{
+namespace arm_compute
+{
+
+//
+// ActivationBuilder
+//
+class ActivationBuilder
+{
+public:
+  ActivationBuilder(IExecutionBuilder &builder) : _builder(builder)
+  {
+    // DO NOTHING
+  }
+
+private:
+  void appendReLU(::arm_compute::ICLTensor *tensor);
+
+public:
+  void append(FuseCode code, ::arm_compute::ICLTensor *tensor);
+
+private:
+  IExecutionBuilder &_builder;
+};
+
+void ActivationBuilder::appendReLU(::arm_compute::ICLTensor *ifm_alloc)
+{
+  const ::arm_compute::ActivationLayerInfo act_info{
+      ::arm_compute::ActivationLayerInfo::ActivationFunction::RELU};
+
+  auto fn = make_layer<::arm_compute::CLActivationLayer>();
+
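+  // Passing nullptr as the output applies the activation in place on ifm_alloc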
+  fn->configure(ifm_alloc, nullptr, act_info);
+
+  _builder.append(std::move(fn));
+}
+
+void ActivationBuilder::append(FuseCode code, ::arm_compute::ICLTensor *ifm_alloc)
+{
+  switch (code)
+  {
+    case ANEURALNETWORKS_FUSED_NONE:
+    {
+      // DO NOTHING
+      break;
+    }
+    case ANEURALNETWORKS_FUSED_RELU:
+    {
+      appendReLU(ifm_alloc);
+      break;
+    }
+    default:
+    {
+      throw std::runtime_error("Not supported, yet");
+    }
+  }
+}
+
+//
+// StageGenerator
+//
+StageGenerator::StageGenerator(const ::internal::tflite::operand::Set &ctx,
+                               const std::shared_ptr<::internal::arm_compute::TensorBuilder> &tensor_builder)
+    : _ctx(ctx), _tensor_builder(tensor_builder)
+{
+  // DO NOTHING
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Conv2D::implicit::Node &node)
+{
+  const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+  const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+  const ::internal::tflite::operand::Index ker_index{node.param().ker_index};
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+
+  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
+  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
+
+  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
+  const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
+
+  const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
+  const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
+  const auto ker_shape = _ctx.at(ker_index).shape().asKernel();
+  const auto bias_size = _ctx.at(bias_index).shape().asVector();
+
+  const PaddingCode padding_type =
+      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
+
+  assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
+         (ANEURALNETWORKS_PADDING_VALID == padding_type));
+
+  Stride stride;
+
+  stride.vertical = _ctx.at(vstride_index).asScalar<int32_t>();
+  stride.horizontal = _ctx.at(hstride_index).asScalar<int32_t>();
+
+  // Construct operation parameters
+  struct Param
+  {
+    int ofm_index;
+    int ifm_index;
+    int ker_index;
+    int bias_index;
+
+    Padding padding;
+    Stride stride;
+
+    FuseCode activation;
+  };
+
+  Param param;
+
+  param.ofm_index = ofm_index.asInt();
+  param.ifm_index = ifm_index.asInt();
+  param.ker_index = ker_index.asInt();
+  param.bias_index = bias_index.asInt();
+
+  param.stride = stride;
+  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
+                      ? same_padding(ifm_shape, ofm_shape, stride, ker_shape.W, ker_shape.H)
+                      : valid_padding();
+
+  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
+
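+  // Capture a copy of the shared tensor builder so the returned stage can resolve its tensors later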
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto ofm_alloc = tensors->at(::internal::tflite::operand::Index{param.ofm_index}).get();
+    auto ifm_alloc = tensors->at(::internal::tflite::operand::Index{param.ifm_index}).get();
+    auto ker_alloc = tensors->at(::internal::tflite::operand::Index{param.ker_index}).get();
+    auto bias_alloc = tensors->at(::internal::tflite::operand::Index{param.bias_index}).get();
+
+    const auto conv_info = asPadStrideInfo(param.padding, param.stride);
+
+    auto fn = make_layer<::arm_compute::CLConvolutionLayer>();
+
+    fn->configure(ifm_alloc, ker_alloc, bias_alloc, ofm_alloc, conv_info);
+
+    builder.append(std::move(fn));
+
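+    // Append the fused activation (if any) on the convolution output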
+    ActivationBuilder{builder}.append(param.activation, ofm_alloc);
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node)
+{
+  const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+  const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+
+  const ::internal::tflite::operand::Index kh_index{node.param().kh_index};
+  const ::internal::tflite::operand::Index kw_index{node.param().kw_index};
+
+  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
+  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
+
+  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
+
+  const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
+  const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
+
+  const int32_t kh = _ctx.at(kh_index).asScalar<int32_t>();
+  const int32_t kw = _ctx.at(kw_index).asScalar<int32_t>();
+
+  const int32_t vstride = _ctx.at(vstride_index).asScalar<int32_t>();
+  const int32_t hstride = _ctx.at(hstride_index).asScalar<int32_t>();
+
+  const PaddingCode padding_type =
+      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
+
+  // Construct operation parameters
+  struct Param
+  {
+    int ofm_index;
+    int ifm_index;
+
+    uint32_t kw;
+    uint32_t kh;
+
+    Padding padding;
+    Stride stride;
+
+    // TODO Add 'activation' field
+  };
+
+  Param param;
+
+  param.ofm_index = ofm_index.asInt();
+  param.ifm_index = ifm_index.asInt();
+
+  param.kh = kh;
+  param.kw = kw;
+
+  param.stride.vertical = vstride;
+  param.stride.horizontal = hstride;
+
+  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
+                      ? same_padding(ifm_shape, ofm_shape, param.stride, kw, kh)
+                      : valid_padding();
+
+  VERBOSE(MaxPool2D) << "IFM_H: " << ifm_shape.H << std::endl;
+  VERBOSE(MaxPool2D) << "IFM_W: " << ifm_shape.W << std::endl;
+  VERBOSE(MaxPool2D) << "OFM_H: " << ofm_shape.H << std::endl;
+  VERBOSE(MaxPool2D) << "OFM_W: " << ofm_shape.W << std::endl;
+  VERBOSE(MaxPool2D) << "KER_H: " << kh << std::endl;
+  VERBOSE(MaxPool2D) << "KER_W: " << kw << std::endl;
+  VERBOSE(MaxPool2D) << "STRIDE_H: " << vstride << std::endl;
+  VERBOSE(MaxPool2D) << "STRIDE_W: " << hstride << std::endl;
+  VERBOSE(MaxPool2D) << "PAD(T): " << param.padding.top << std::endl;
+  VERBOSE(MaxPool2D) << "PAD(B): " << param.padding.bottom << std::endl;
+  VERBOSE(MaxPool2D) << "PAD(L): " << param.padding.left << std::endl;
+  VERBOSE(MaxPool2D) << "PAD(R): " << param.padding.right << std::endl;
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto ofm_alloc = tensors->at(::internal::tflite::operand::Index{param.ofm_index}).get();
+    auto ifm_alloc = tensors->at(::internal::tflite::operand::Index{param.ifm_index}).get();
+
+    ::arm_compute::PoolingLayerInfo info{::arm_compute::PoolingType::MAX,
+                                         ::arm_compute::Size2D{param.kw, param.kh},
+                                         asPadStrideInfo(param.padding, param.stride)};
+
+    auto fn = make_layer<::arm_compute::CLPoolingLayer>();
+
+    fn->configure(ifm_alloc, ofm_alloc, info);
+
+    builder.append(std::move(fn));
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node)
+{
+  const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+  const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+
+  const ::internal::tflite::operand::Index kh_index{node.param().kh_index};
+  const ::internal::tflite::operand::Index kw_index{node.param().kw_index};
+
+  const ::internal::tflite::operand::Index vstride_index{node.param().vstride_index};
+  const ::internal::tflite::operand::Index hstride_index{node.param().hstride_index};
+
+  const ::internal::tflite::operand::Index padding_index{node.param().padding_index};
+
+  const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
+  const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
+
+  const int32_t kh = _ctx.at(kh_index).asScalar<int32_t>();
+  const int32_t kw = _ctx.at(kw_index).asScalar<int32_t>();
+
+  const int32_t vstride = _ctx.at(vstride_index).asScalar<int32_t>();
+  const int32_t hstride = _ctx.at(hstride_index).asScalar<int32_t>();
+
+  const PaddingCode padding_type =
+      static_cast<PaddingCode>(_ctx.at(padding_index).asScalar<int32_t>());
+
+  assert((ANEURALNETWORKS_PADDING_SAME == padding_type) ||
+         (ANEURALNETWORKS_PADDING_VALID == padding_type));
+
+  // Construct operation parameters
+  struct Param
+  {
+    int ofm_index;
+    int ifm_index;
+
+    uint32_t kw;
+    uint32_t kh;
+
+    Padding padding;
+    Stride stride;
+
+    // TODO Add 'activation' field
+  };
+
+  Param param;
+
+  param.ofm_index = ofm_index.asInt();
+  param.ifm_index = ifm_index.asInt();
+
+  param.kh = kh;
+  param.kw = kw;
+
+  param.stride.vertical = vstride;
+  param.stride.horizontal = hstride;
+
+  param.padding = (padding_type == ANEURALNETWORKS_PADDING_SAME)
+                      ? same_padding(ifm_shape, ofm_shape, param.stride, kw, kh)
+                      : valid_padding();
+
+  VERBOSE(AvgPool2D) << "IFM_H: " << ifm_shape.H << std::endl;
+  VERBOSE(AvgPool2D) << "IFM_W: " << ifm_shape.W << std::endl;
+  VERBOSE(AvgPool2D) << "OFM_H: " << ofm_shape.H << std::endl;
+  VERBOSE(AvgPool2D) << "OFM_W: " << ofm_shape.W << std::endl;
+  VERBOSE(AvgPool2D) << "KER_H: " << kh << std::endl;
+  VERBOSE(AvgPool2D) << "KER_W: " << kw << std::endl;
+  VERBOSE(AvgPool2D) << "STRIDE_H: " << vstride << std::endl;
+  VERBOSE(AvgPool2D) << "STRIDE_W: " << hstride << std::endl;
+  VERBOSE(AvgPool2D) << "PAD: " << to_string(padding_type) << std::endl;
+  VERBOSE(AvgPool2D) << "PAD(T): " << param.padding.top << std::endl;
+  VERBOSE(AvgPool2D) << "PAD(B): " << param.padding.bottom << std::endl;
+  VERBOSE(AvgPool2D) << "PAD(L): " << param.padding.left << std::endl;
+  VERBOSE(AvgPool2D) << "PAD(R): " << param.padding.right << std::endl;
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto ofm_alloc = tensors->at(::internal::tflite::operand::Index{param.ofm_index}).get();
+    auto ifm_alloc = tensors->at(::internal::tflite::operand::Index{param.ifm_index}).get();
+
+    ::arm_compute::PoolingLayerInfo info{
+        ::arm_compute::PoolingType::AVG, ::arm_compute::Size2D{param.kw, param.kh},
+        asPadStrideInfo(param.padding, param.stride), true /* exclude_padding */};
+
+    auto fn = make_layer<::arm_compute::CLPoolingLayer>();
+
+    fn->configure(ifm_alloc, ofm_alloc, info);
+
+    builder.append(std::move(fn));
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Concat::Node &node)
+{
+  throw std::runtime_error{"NYI - StageGenerator::generate for 'Concat'"};
+
+  return [](IExecutionBuilder &builder) {
+    // NOTE arm_compute does not have Concat operation
+    // TODO Implement
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::FullyConnected::Node &node)
+{
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+  const ::internal::tflite::operand::Index weight_index{node.param().weight_index};
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+  const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
+
+  // Construct operation parameters
+  struct Param
+  {
+    int output_index;
+
+    int input_index;
+    int weight_index;
+    int bias_index;
+
+    FuseCode activation;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+  param.weight_index = weight_index.asInt();
+  param.bias_index = bias_index.asInt();
+
+  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScalar<int32_t>());
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto output_alloc = tensors->at(::internal::tflite::operand::Index{param.output_index}).get();
+    auto input_alloc = tensors->at(::internal::tflite::operand::Index{param.input_index}).get();
+    auto weight_alloc = tensors->at(::internal::tflite::operand::Index{param.weight_index}).get();
+    auto bias_alloc = tensors->at(::internal::tflite::operand::Index{param.bias_index}).get();
+
+    auto fn = make_layer<::arm_compute::CLFullyConnectedLayer>();
+
+    fn->configure(input_alloc, weight_alloc, bias_alloc, output_alloc);
+
+    builder.append(std::move(fn));
+
+    ActivationBuilder{builder}.append(param.activation, output_alloc);
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Reshape::Node &node)
+{
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  struct Param
+  {
+    int output_index;
+    int input_index;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto output_alloc = tensors->at(::internal::tflite::operand::Index{param.output_index}).get();
+    auto input_alloc = tensors->at(::internal::tflite::operand::Index{param.input_index}).get();
+
+    auto fn = make_layer<::arm_compute::CLReshapeLayer>();
+
+    fn->configure(input_alloc, output_alloc);
+
+    builder.append(std::move(fn));
+  };
+}
+
+Stage StageGenerator::generate(const ::internal::tflite::op::Softmax::Node &node)
+{
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  struct Param
+  {
+    int output_index;
+    int input_index;
+    float scale;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+  // TODO Set scale correctly
+  param.scale = 1.0f;
+
+  auto tensors = _tensor_builder;
+
+  return [tensors, param](IExecutionBuilder &builder) {
+    auto output_alloc = tensors->at(::internal::tflite::operand::Index{param.output_index}).get();
+    auto input_alloc = tensors->at(::internal::tflite::operand::Index{param.input_index}).get();
+
+    auto fn = make_layer<::arm_compute::CLSoftmaxLayer>();
+
+    fn->configure(input_alloc, output_alloc, param.scale);
+
+    builder.append(std::move(fn));
+  };
+}
+
+} // namespace arm_compute
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/arm_compute/StageGenerator.h b/runtimes/new_runtime/src/internal/arm_compute/StageGenerator.h
new file mode 100644 (file)
index 0000000..740db9f
--- /dev/null
@@ -0,0 +1,38 @@
+#ifndef __INTERNAL_ARM_COMPUTE_STAGE_GENERATOR_H__
+#define __INTERNAL_ARM_COMPUTE_STAGE_GENERATOR_H__
+
+#include "internal/IStageGenerator.h"
+
+#include "internal/Model.h"
+#include "internal/arm_compute/TensorBuilder.h"
+
+namespace internal
+{
+namespace arm_compute
+{
+
+class StageGenerator : public ::internal::IStageGenerator
+{
+public:
+  StageGenerator(const ::internal::tflite::operand::Set &ctx,
+                 const std::shared_ptr<::internal::arm_compute::TensorBuilder> &tensor_builder);
+
+  virtual std::shared_ptr<ITensorBuilder> tensor_builder() override { return _tensor_builder; }
+
+  virtual Stage generate(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::Concat::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::FullyConnected::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::Reshape::Node &node) override;
+  virtual Stage generate(const ::internal::tflite::op::Softmax::Node &node) override;
+
+private:
+  const ::internal::tflite::operand::Set &_ctx;
+  std::shared_ptr<::internal::arm_compute::TensorBuilder> _tensor_builder;
+};
+
+} // namespace arm_compute
+} // namespace internal
+
+#endif // __INTERNAL_ARM_COMPUTE_STAGE_GENERATOR_H__
diff --git a/runtimes/new_runtime/src/internal/arm_compute/TensorBuilder.cc b/runtimes/new_runtime/src/internal/arm_compute/TensorBuilder.cc
new file mode 100644 (file)
index 0000000..6706b77
--- /dev/null
@@ -0,0 +1,58 @@
+#include "internal/arm_compute/TensorBuilder.h"
+
+#include <cassert>
+
+#include "internal/arm_compute.h"
+
+namespace internal
+{
+namespace arm_compute
+{
+
+TensorBuilder::TensorBuilder(::internal::arm_compute::Plan &plan) : _plan(plan)
+{
+  // DO NOTHING
+}
+
+void TensorBuilder::mark(const ::internal::tflite::operand::Index &ind)
+{
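+  // Operands may only be marked before prepare() has created the backing tensors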
+  assert(_tensors.size() == 0);
+
+  _inds.insert(ind.asInt());
+}
+
+void TensorBuilder::prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx)
+{
+  assert(_tensors.size() == 0);
+
+  // TODO Handle SubTensor(subsumption)
+  //      Currently this TensorBuilder does not have subsumption info yet
+
+  for (auto ind_int : _inds)
+  {
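+    // Create a CLTensor for each marked operand, register it with the plan,
+    // and keep it so allocate() can back it with device memory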
+    ::internal::tflite::operand::Index ind{ind_int};
+    auto tensor = std::make_shared<::arm_compute::CLTensor>();
+    tensor->allocator()->init(tensor_info_ctx.at(ind.asInt()));
+    _plan.operands().set(ind, std::make_shared<::internal::arm_compute::operand::Object>(tensor));
+    _tensors[ind.asInt()] = tensor;
+  }
+}
+
+void TensorBuilder::allocate(void)
+{
+  assert(_inds.size() == _tensors.size());
+
+  for (const auto &tensor_entry : _tensors)
+  {
+    auto tensor = tensor_entry.second;
+    tensor->allocator()->allocate();
+  }
+}
+
+std::shared_ptr<::arm_compute::CLTensor> TensorBuilder::at(const ::internal::tflite::operand::Index &ind)
+{
+  return _tensors.at(ind.asInt());
+}
+
+} // namespace arm_compute
+} // namespace internal
diff --git a/runtimes/new_runtime/src/internal/arm_compute/TensorBuilder.h b/runtimes/new_runtime/src/internal/arm_compute/TensorBuilder.h
new file mode 100644 (file)
index 0000000..c8c28f0
--- /dev/null
@@ -0,0 +1,38 @@
+#ifndef __INTERNAL_ARM_COMPUTE_TENSOR_BUILDER_H__
+#define __INTERNAL_ARM_COMPUTE_TENSOR_BUILDER_H__
+
+#include "internal/ITensorBuilder.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include <arm_compute/runtime/CL/CLTensor.h>
+
+namespace internal
+{
+namespace arm_compute
+{
+
+class Plan;
+
+class TensorBuilder : public ::internal::ITensorBuilder
+{
+public:
+  TensorBuilder(::internal::arm_compute::Plan &plan);
+
+  virtual void mark(const ::internal::tflite::operand::Index &ind) override;
+  virtual void prepare(const std::map<int, ::arm_compute::TensorInfo> &tensor_info_ctx) override;
+  virtual void allocate(void) override;
+
+  std::shared_ptr<::arm_compute::CLTensor> at(const ::internal::tflite::operand::Index &ind);
+
+private:
+  ::internal::arm_compute::Plan &_plan;
+  std::unordered_set<int> _inds;
+  std::unordered_map<int, std::shared_ptr<::arm_compute::CLTensor>> _tensors;
+};
+
+} // namespace arm_compute
+} // namespace internal
+
+#endif // __INTERNAL_ARM_COMPUTE_TENSOR_BUILDER_H__