From ab55b8380985d2654aef32f227ec25114b6cbee5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vishal=20Keshav/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 18 Oct 2018 17:12:09 +0530 Subject: [PATCH] Adding Node for pack operation (#3193) Adding operation node for pack operation Signed-off-by: Vishal keshav --- runtimes/pure_arm_compute/src/compilation.cc | 8 ++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/internal/op/Pack.cc | 49 +++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Pack.h | 56 ++++++++++++++++++++++ runtimes/pure_arm_compute/src/model.cc | 11 +++++ 5 files changed, 126 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Pack.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Pack.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 55b59d0..e66321d 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -552,6 +552,7 @@ public: void visit(const ::internal::tflite::op::ReduceSum::Node &node) override; void visit(const ::internal::tflite::op::Equal::Node &node) override; void visit(const ::internal::tflite::op::TransposeConv::Node &node) override; + void visit(const ::internal::tflite::op::Pack::Node &node) override; void visit(const ::internal::tflite::op::Abs::Node &node) override; private: @@ -4390,6 +4391,13 @@ void Planner::visit(const ::internal::tflite::op::DepthToSpace::Node &node) _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::Pack::Node &node) +{ + VERBOSE(Pack) << "Configure Pack operation" << std::endl; + + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::Neg::Node &node) { const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 3ccd97e..d97bcad 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -71,6 +71,7 @@ #include "internal/op/Equal.h" #include "internal/op/BatchToSpaceNd.h" #include "internal/op/TransposeConv.h" +#include "internal/op/Pack.h" #include "internal/op/Abs.h" namespace internal @@ -363,6 +364,7 @@ struct NodeVisitor virtual void visit(const Equal::Node &) = 0; virtual void visit(const BatchToSpaceNd::Node &) = 0; virtual void visit(const TransposeConv::Node &) = 0; + virtual void visit(const Pack::Node &) = 0; virtual void visit(const Abs::Node &) = 0; }; diff --git a/runtimes/pure_arm_compute/src/internal/op/Pack.cc b/runtimes/pure_arm_compute/src/internal/op/Pack.cc new file mode 100644 index 0000000..c5ba8eb --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Pack.cc @@ -0,0 +1,49 @@ + +#include "internal/op/Pack.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Pack +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace Pack +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Pack +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(outputCount == 1); + + ofm_index = outputs[0]; + // last input is axis along which packing is required + for (uint32_t n = 0; n < inputCount - 1; ++n) + { + ifm_indexes.emplace_back(inputs[n]); + } + axis_index = inputs[inputCount - 1]; +} + +} // namespace Pack +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Pack.h b/runtimes/pure_arm_compute/src/internal/op/Pack.h new file mode 100644 index 0000000..a1dc913 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Pack.h @@ -0,0 +1,56 @@ +#ifndef __INTERNAL_OP_PACK_H__ +#define __INTERNAL_OP_PACK_H__ + +#include "internal/op/Node.h" +#include + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Pack +{ + +struct Param +{ + int32_t ofm_index; + // There are N+1 inputs, 0 to N-1 are tensors of same shape + // Nth input is axis index along which stack is needed to be done. + std::vector ifm_indexes; + int32_t axis_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace Pack +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_PACK_H__ diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index b179658..2853eaf 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -882,6 +882,17 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_PACK_EX: + { + using internal::tflite::op::Pack::Param; + using internal::tflite::op::Pack::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_ABS_EX: { using internal::tflite::op::Abs::Param; -- 2.7.4