From 3ce653ef637a3bbbdab4f426d588fe447776dac1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Shubham=20Gupta/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 19 Sep 2018 07:02:34 +0530 Subject: [PATCH] Implementation for LRN Layer (#2643) This patch implements Local Response Normalization Layer in NNFW-PACL Signed-off-by: shubham --- runtimes/pure_arm_compute/src/compilation.cc | 67 ++++++++++++++++++++++ .../src/internal/op/LocalResponseNormalization.cc | 48 ++++++++++++++++ .../src/internal/op/LocalResponseNormalization.h | 57 ++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 12 ++++ 5 files changed, 186 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index ade2809..5da0b79 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -513,6 +513,7 @@ public: void visit(const ::internal::tflite::op::HashtableLookup::Node &node) override; void visit(const ::internal::tflite::op::L2Normalization::Node &node) override; void visit(const ::internal::tflite::op::SquaredDifference::Node &node) override; + void visit(const ::internal::tflite::op::LocalResponseNormalization::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -3898,6 +3899,72 @@ void Planner::visit(const ::internal::tflite::op::HashtableLookup::Node &node) throw std::runtime_error("Not supported"); } +void Planner::visit(const ::internal::tflite::op::LocalResponseNormalization::Node &node) +{ + const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; + const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index}; + const ::internal::tflite::operand::Index radius_index{node.param().radius_index}; + const ::internal::tflite::operand::Index bias_index{node.param().bias_index}; + const ::internal::tflite::operand::Index alpha_index{node.param().alpha_index}; + const ::internal::tflite::operand::Index beta_index{node.param().beta_index}; + + // Set Shape Constraints and TensorInfo + _builder.addShapeConstr( + ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(), + _ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint())); + _builder.addShapeConstr( + ofm_index, asTensorInfo(asTensorShape(_ctx.at(ofm_index).shape()), _ctx.at(ofm_index).type(), + _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint())); + + // Construct operation parameters + struct Param + { + int ofm_index; + int ifm_index; + int32_t radius; + float bias; + float alpha; + float beta; + }; + + Param param; + + param.ofm_index = ofm_index.asInt(); + param.ifm_index = ifm_index.asInt(); + + param.radius = _ctx.at(radius_index).asScalar(); + param.alpha = _ctx.at(alpha_index).asScalar(); + param.beta = _ctx.at(beta_index).asScalar(); + param.bias = _ctx.at(bias_index).asScalar(); + + auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) { + auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index}); + auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index}); + + const auto norm_info = + ::arm_compute::NormalizationLayerInfo(::arm_compute::NormType::CROSS_MAP, param.radius, + param.alpha, param.beta, param.bias, false); + if (::internal::arm_compute::isGpuMode()) + { + auto fn = nnfw::make_unique<::arm_compute::CLNormalizationLayer>(); + + fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc), norm_info); + + builder.append("LocalResponseNormalization", std::move(fn)); + } + else + { + auto fn = nnfw::make_unique<::arm_compute::NENormalizationLayer>(); + + fn->configure(ifm_alloc, ofm_alloc, norm_info); + + builder.append("LocalResponseNormalization", std::move(fn)); + } + }; + + _builder.addStage(stage); +} + class AllocationContext final : public IAllocationContext { public: diff --git a/runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.cc b/runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.cc new file mode 100644 index 0000000..3d0aa6a --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.cc @@ -0,0 +1,48 @@ +#include "internal/op/LocalResponseNormalization.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace LocalResponseNormalization +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace LocalResponseNormalization +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace LocalResponseNormalization +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 5 && outputCount == 1); + + ofm_index = outputs[0]; + + ifm_index = inputs[0]; + radius_index = inputs[1]; + bias_index = inputs[2]; + alpha_index = inputs[3]; + beta_index = inputs[4]; +} + +} // namespace LocalResponseNormalization +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.h b/runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.h new file mode 100644 index 0000000..53961dc --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/LocalResponseNormalization.h @@ -0,0 +1,57 @@ +#ifndef __INTERNAL_OP_LOCAL_RESPONSE_NORMALIZATION_H__ +#define __INTERNAL_OP_LOCAL_RESPONSE_NORMALIZATION_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace LocalResponseNormalization +{ + +struct Param +{ + int32_t ofm_index; + + int32_t ifm_index; + int32_t radius_index; + int32_t bias_index; + int32_t alpha_index; + int32_t beta_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace LocalResponseNormalization +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_LOCAL_RESPONSE_NORMALIZATION_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 6d8d10a..56585a3 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -56,6 +56,7 @@ #include "internal/op/HashtableLookup.h" #include "internal/op/L2Normalization.h" #include "internal/op/SquaredDifference.h" +#include "internal/op/LocalResponseNormalization.h" namespace internal { @@ -112,6 +113,7 @@ struct NodeVisitor virtual void visit(const HashtableLookup::Node &) = 0; virtual void visit(const L2Normalization::Node &) = 0; virtual void visit(const SquaredDifference::Node &) = 0; + virtual void visit(const LocalResponseNormalization::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 49ea59f..3a7db99 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -675,6 +675,18 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION: + { + + using internal::tflite::op::LocalResponseNormalization::Param; + using internal::tflite::op::LocalResponseNormalization::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } default: throw std::runtime_error{"Not supported operation"}; }; -- 2.7.4