[PureCL] Support DIV operation (#1482)
author서상민/동작제어Lab(SR)/Staff Engineer/삼성전자 <sangmin7.seo@samsung.com>
Fri, 1 Jun 2018 00:38:50 +0000 (09:38 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Fri, 1 Jun 2018 00:38:50 +0000 (09:38 +0900)
* [PureCL] Support DIV operation

For #1338 and #1366

This patch adds DIV operation to the pure ACL runtime.

Signed-off-by: Sangmin Seo <sangmin7.seo@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc
runtimes/pure_arm_compute/src/internal/op/Div.cc [new file with mode: 0644]
runtimes/pure_arm_compute/src/internal/op/Div.h [new file with mode: 0644]
runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h
runtimes/pure_arm_compute/src/model.cc

index 46846ce..6d7b08e 100644 (file)
@@ -6,6 +6,7 @@
 #include <arm_compute/runtime/CL/CLScheduler.h>
 #include <arm_compute/runtime/CL/CLSubTensor.h>
 #include <arm_compute/runtime/CL/functions/CLArithmeticAddition.h>
+#include <arm_compute/runtime/CL/functions/CLPixelWiseDivision.h>
 #include <arm_compute/runtime/CL/functions/CLPoolingLayer.h>
 #include <arm_compute/runtime/CL/functions/CLActivationLayer.h>
 #include <arm_compute/runtime/CL/functions/CLScale.h>
@@ -267,6 +268,7 @@ public:
 
 public:
   void visit(const ::internal::tflite::op::Add::Node &node) override;
+  void visit(const ::internal::tflite::op::Div::Node &node) override;
   void visit(const ::internal::tflite::op::Conv2D::implicit::Node &node) override;
   void visit(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) override;
   void visit(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) override;
@@ -359,6 +361,109 @@ void Planner::visit(const ::internal::tflite::op::Add::Node &node)
   _builder.addStage(stage);
 }
 
+void Planner::visit(const ::internal::tflite::op::Div::Node &node)
+{
+  const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+
+  const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
+  const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
+
+  const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
+
+  // TODO Support general broadcasting. Currently, broadcast works only when one operand is scalar
+  //      or the operand's dimension size is one.
+  const auto ofm_shape = _ctx.at(ofm_index).shape();
+  const auto ofm_shape_rank = ofm_shape.rank();
+  if (ofm_shape_rank == 4)
+  {
+    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape.asFeature()));
+  }
+  else if (ofm_shape_rank == 1)
+  {
+    _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape.asVector()));
+  }
+  else
+  {
+    throw std::runtime_error("Not supported, yet");
+  }
+
+  const auto lhs_shape = _ctx.at(lhs_index).shape();
+  const auto lhs_shape_rank = lhs_shape.rank();
+  if (lhs_shape_rank == 4)
+  {
+    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape.asFeature()));
+  }
+  else if (lhs_shape_rank == 1)
+  {
+    _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape.asVector()));
+  }
+  else if (lhs_shape_rank == 0)
+  {
+    // scalar
+    _builder.addShapeConstr(lhs_index, asTensorInfo(1));
+  }
+  else
+  {
+    throw std::runtime_error("Not supported, yet");
+  }
+
+  const auto rhs_shape = _ctx.at(rhs_index).shape();
+  const auto rhs_shape_rank = rhs_shape.rank();
+  if (rhs_shape_rank == 4)
+  {
+    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape.asFeature()));
+  }
+  else if (rhs_shape_rank == 1)
+  {
+    _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape.asVector()));
+  }
+  else if (rhs_shape_rank == 0)
+  {
+    // scalar
+    _builder.addShapeConstr(rhs_index, asTensorInfo(1));
+  }
+  else
+  {
+    throw std::runtime_error("Not supported, yet");
+  }
+
+  // Construct operation parameters
+  struct Param
+  {
+    int ofm_index;
+    int lhs_index;
+    int rhs_index;
+
+    FuseCode activation;
+  };
+
+  Param param;
+
+  param.ofm_index = ofm_index.asInt();
+  param.lhs_index = lhs_index.asInt();
+  param.rhs_index = rhs_index.asInt();
+
+  param.activation = static_cast<FuseCode>(_ctx.at(activation_index).asScala<int32_t>());
+
+  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+    auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
+    auto lhs_alloc = ctx.at(::internal::tflite::operand::Index{param.lhs_index});
+    auto rhs_alloc = ctx.at(::internal::tflite::operand::Index{param.rhs_index});
+
+    auto fn = make_layer<::arm_compute::CLPixelWiseDivision>();
+
+    // TODO Decide scale, overflow_policy, and rounding_policy.
+    //      Currently, the default values are used.
+    fn->configure(lhs_alloc, rhs_alloc, ofm_alloc);
+
+    builder.append(std::move(fn));
+
+    ActivationBuilder{builder}.append(param.activation, ofm_alloc);
+  };
+
+  _builder.addStage(stage);
+}
+
 void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node)
 {
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
@@ -1077,6 +1182,8 @@ void PlanBuilder::addStage(const Stage &stage) { _stages.emplace_back(stage); }
 
 #include <stack>
 
+using namespace std::placeholders;
+
 static void initFeatureTensor(::arm_compute::ITensor &tensor,
                               const nnfw::util::feature::Shape &feature_shape,
                               const uint8_t *feature_base, const size_t feature_size)
@@ -1225,26 +1332,28 @@ void PlanBuilder::finalize(void) const
     {
       const auto rank = operands.at(operand_idx).shape().rank();
       auto base = operands.at(operand_idx).data().base();
-      ::arm_compute::ICLTensor &tensor = *(_plan.operands().at(operand_idx).ptr());
 
       switch (rank)
       {
         case 0: // scalar
         {
-          initVectorTensor(tensor, base, 1);
+          auto initializer = std::bind(initVectorTensor, _1, base, 1);
+          _plan.operands().at(operand_idx).access(initializer);
           break;
         }
         case 1: // vector
         {
           auto size = operands.at(operand_idx).shape().asVector();
-          initVectorTensor(tensor, base, size);
+          auto initializer = std::bind(initVectorTensor, _1, base, size);
+          _plan.operands().at(operand_idx).access(initializer);
           break;
         }
         case 4: // feature
         {
           const auto feature_shape = operands.at(operand_idx).shape().asFeature();
           auto size = operands.at(operand_idx).data().size();
-          initFeatureTensor(tensor, feature_shape, base, size);
+          auto initializer = std::bind(initFeatureTensor, _1, feature_shape, base, size);
+          _plan.operands().at(operand_idx).access(initializer);
           break;
         }
         default:
diff --git a/runtimes/pure_arm_compute/src/internal/op/Div.cc b/runtimes/pure_arm_compute/src/internal/op/Div.cc
new file mode 100644 (file)
index 0000000..9de2222
--- /dev/null
@@ -0,0 +1,51 @@
+#include "internal/op/Div.h"
+#include "internal/op/NodeVisitor.h"
+
+#include <cassert>
+
+namespace internal
+{
+namespace tflite
+{
+namespace op
+{
+namespace Div
+{
+
+void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
+
+} // namespace Div
+} // namespace op
+} // namespace tflite
+} // namespace internal
+
+namespace internal
+{
+namespace tflite
+{
+namespace op
+{
+namespace Div
+{
+
+Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount,
+             const uint32_t *outputs)
+{
+  assert(inputCount == 3 && outputCount == 1);
+
+  ofm_index = outputs[0];
+
+  // Each input should be interpreted as follows:
+  //
+  //  0 -> LHS Tensor Index
+  //  1 -> RHS Tensor Index
+  //  2 -> Activation Index
+  lhs_index = inputs[0];
+  rhs_index = inputs[1];
+  activation_index = inputs[2];
+}
+
+} // namespace Div
+} // namespace op
+} // namespace tflite
+} // namespace internal
diff --git a/runtimes/pure_arm_compute/src/internal/op/Div.h b/runtimes/pure_arm_compute/src/internal/op/Div.h
new file mode 100644 (file)
index 0000000..2b4b87f
--- /dev/null
@@ -0,0 +1,55 @@
+#ifndef __INTERNAL_OP_DIV_H__
+#define __INTERNAL_OP_DIV_H__
+
+#include "internal/op/Node.h"
+
+#include <cstdint>
+
+namespace internal
+{
+namespace tflite
+{
+namespace op
+{
+namespace Div
+{
+
+struct Param
+{
+  int32_t ofm_index;
+
+  int32_t lhs_index;
+  int32_t rhs_index;
+  int32_t activation_index;
+
+  Param() = default;
+  Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs);
+};
+
+class Node final : public op::Node
+{
+public:
+  Node(const Param &param) : _param(param)
+  {
+    // DO NOTHING
+  }
+
+public:
+  virtual ~Node() = default;
+
+public:
+  const Param &param(void) const { return _param; }
+
+public:
+  void accept(NodeVisitor &&) const override;
+
+private:
+  const Param _param;
+};
+
+} // namespace Div
+} // namespace op
+} // namespace tflite
+} // namespace internal
+
+#endif // __INTERNAL_OP_DIV_H__
index 8a5829a..0295bce 100644 (file)
@@ -2,6 +2,7 @@
 #define __INTERNAL_OP_NODE_VISITOR_H__
 
 #include "internal/op/Add.h"
+#include "internal/op/Div.h"
 #include "internal/op/Conv2D.h"
 #include "internal/op/MaxPool2D.h"
 #include "internal/op/AvgPool2D.h"
@@ -23,6 +24,7 @@ struct NodeVisitor
   virtual ~NodeVisitor() = default;
 
   virtual void visit(const Add::Node &) = 0;
+  virtual void visit(const Div::Node &) = 0;
   virtual void visit(const Conv2D::implicit::Node &) = 0;
   virtual void visit(const MaxPool2D::implicit::Node &) = 0;
   virtual void visit(const AvgPool2D::implicit::Node &) = 0;
index 7e59d12..83fd27b 100644 (file)
@@ -88,6 +88,21 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model,
 
       break;
     }
+    case ANEURALNETWORKS_DIV:
+    {
+      assert(inputCount == 3);
+      assert(outputCount == 1);
+
+      using internal::tflite::op::Div::Param;
+      using internal::tflite::op::Div::Node;
+
+      // Add 'operations'
+      auto &operations = model->deref().operations();
+
+      operations.emplace_back<Node>(Param{inputCount, inputs, outputCount, outputs});
+
+      break;
+    }
     case ANEURALNETWORKS_CONV_2D:
     {
       // inputCount is either 7 or 9 acccording to NN API specification.