From 7f1440bfa4692378b2459ff20c05923905890106 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Senior=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Thu, 24 May 2018 12:55:44 +0900 Subject: [PATCH] [Pure CL Runtime] Support 'ADD' operation (#1318) This commit implements 'ADD' operation support in pure CL runtime. Signed-off-by: Jonghyun Park --- runtimes/pure_arm_compute/src/compilation.cc | 57 ++++++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Add.cc | 51 +++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Add.h | 55 +++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 15 ++++++ 5 files changed, 180 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Add.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Add.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index c95757b..8c4dd4d 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -245,6 +246,7 @@ public: } public: + void visit(const ::internal::tflite::op::Add::Node &node) override; void visit(const ::internal::tflite::op::Conv2D::implicit::Node &node) override; void visit(const ::internal::tflite::op::MaxPool2D::implicit::Node &node) override; void visit(const ::internal::tflite::op::AvgPool2D::implicit::Node &node) override; @@ -258,6 +260,61 @@ private: IPlanBuilder &_builder; }; +void Planner::visit(const ::internal::tflite::op::Add::Node &node) +{ + const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; + + const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index}; + const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index}; + + const ::internal::tflite::operand::Index activation_index{node.param().activation_index}; + + // TODO Support generic tensor shape + const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature(); + const auto lhs_shape = _ctx.at(lhs_index).shape().asFeature(); + const auto rhs_shape = _ctx.at(rhs_index).shape().asFeature(); + + // Set Shape Constraints + _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape)); + _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape)); + _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape)); + + // Construct operation parameters + struct Param + { + int ofm_index; + int lhs_index; + int rhs_index; + + FuseCode activation; + }; + + Param param; + + param.ofm_index = ofm_index.asInt(); + param.lhs_index = lhs_index.asInt(); + param.rhs_index = rhs_index.asInt(); + + param.activation = static_cast(_ctx.at(activation_index).asScala()); + + auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) { + auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index}); + auto lhs_alloc = ctx.at(::internal::tflite::operand::Index{param.lhs_index}); + auto rhs_alloc = ctx.at(::internal::tflite::operand::Index{param.rhs_index}); + + auto fn = make_layer<::arm_compute::CLArithmeticAddition>(); + + // TODO Decide ConvertPolicy (WARP? SATURATE?) according to NN API specification + fn->configure(lhs_alloc, rhs_alloc, ofm_alloc, ::arm_compute::ConvertPolicy::SATURATE); + + builder.append(std::move(fn)); + + ActivationBuilder{builder}.append(param.activation, ofm_alloc); + }; + + _builder.addStage(stage); +} + void Planner::visit(const ::internal::tflite::op::Conv2D::implicit::Node &node) { const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; diff --git a/runtimes/pure_arm_compute/src/internal/op/Add.cc b/runtimes/pure_arm_compute/src/internal/op/Add.cc new file mode 100644 index 0000000..87ba33e --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Add.cc @@ -0,0 +1,51 @@ +#include "internal/op/Add.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Add +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace Add +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Add +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 3 && outputCount == 1); + + ofm_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> LHS Tensor Index + // 1 -> RHS Tensor Index + // 2 -> Activation Index + lhs_index = inputs[0]; + rhs_index = inputs[1]; + activation_index = inputs[2]; +} + +} // namespace Add +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Add.h b/runtimes/pure_arm_compute/src/internal/op/Add.h new file mode 100644 index 0000000..ac62fa5 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Add.h @@ -0,0 +1,55 @@ +#ifndef __INTERNAL_OP_ADD_H__ +#define __INTERNAL_OP_ADD_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Add +{ + +struct Param +{ + int32_t ofm_index; + + int32_t lhs_index; + int32_t rhs_index; + int32_t activation_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace Add +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_ADD_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index e44467e..c787be7 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -1,6 +1,7 @@ #ifndef __INTERNAL_OP_NODE_VISITOR_H__ #define __INTERNAL_OP_NODE_VISITOR_H__ +#include "internal/op/Add.h" #include "internal/op/Conv2D.h" #include "internal/op/MaxPool2D.h" #include "internal/op/AvgPool2D.h" @@ -20,6 +21,7 @@ struct NodeVisitor { virtual ~NodeVisitor() = default; + virtual void visit(const Add::Node &) = 0; virtual void visit(const Conv2D::implicit::Node &) = 0; virtual void visit(const MaxPool2D::implicit::Node &) = 0; virtual void visit(const AvgPool2D::implicit::Node &) = 0; diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index ada6cee..e2f77ab 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -76,6 +76,21 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, { switch (type) { + case ANEURALNETWORKS_ADD: + { + assert(inputCount == 3); + assert(outputCount == 1); + + using internal::tflite::op::Add::Param; + using internal::tflite::op::Add::Node; + + // Add 'operations' + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_CONV_2D: { // inputCount is either 7 or 9 acccording to NN API specification. -- 2.7.4