From 7d8e634aedbac57653521b4daaee37636d999e57 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=98=A4=ED=98=95=EC=84=9D/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Staff=20Engineer/=EC=82=BC=EC=84=B1?= =?utf8?q?=EC=A0=84=EC=9E=90?= Date: Fri, 1 Jun 2018 15:29:18 +0900 Subject: [PATCH] Define Max namespace in pure acl runtime (#1506) Define Max namespace, node, param in pure acl runtime Signed-off-by: Hyeongseok Oh --- runtimes/pure_arm_compute/src/compilation.cc | 7 +++ runtimes/pure_arm_compute/src/internal/op/Max.cc | 49 ++++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Max.h | 54 ++++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 12 +++++ 5 files changed, 124 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Max.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Max.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 6d7b08e..6080463 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -277,6 +277,7 @@ public: void visit(const ::internal::tflite::op::ResizeBilinear::Node &node) override; void visit(const ::internal::tflite::op::Reshape::Node &node) override; void visit(const ::internal::tflite::op::Softmax::Node &node) override; + void visit(const ::internal::tflite::op::Max::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -1062,6 +1063,12 @@ void Planner::visit(const ::internal::tflite::op::Softmax::Node &node) _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::Max::Node &node) +{ + VERBOSE(Max) << "Configure MAX operation" << std::endl; + throw std::runtime_error{"Not supported operation"}; +} + class AllocationContext final : public IAllocationContext { public: diff --git a/runtimes/pure_arm_compute/src/internal/op/Max.cc b/runtimes/pure_arm_compute/src/internal/op/Max.cc new file mode 100644 index 0000000..8ac26db --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Max.cc @@ -0,0 +1,49 @@ +#include "internal/op/Max.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Max +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace Max +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Max +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 2 && outputCount == 1); + + ofm_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> Input Tensor Index + // 1 -> Axis Tensor Index + ifm_index = inputs[0]; + axis_index = inputs[1]; +} + +} // namespace Max +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Max.h b/runtimes/pure_arm_compute/src/internal/op/Max.h new file mode 100644 index 0000000..954625f --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Max.h @@ -0,0 +1,54 @@ +#ifndef __INTERNAL_OP_MAX_H__ +#define __INTERNAL_OP_MAX_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Max +{ + +struct Param +{ + int32_t ofm_index; + + int32_t ifm_index; + int32_t axis_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace Max +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_MAX_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 0295bce..e547ea2 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -11,6 +11,7 @@ #include "internal/op/ResizeBilinear.h" #include "internal/op/FullyConnected.h" #include "internal/op/Softmax.h" +#include "internal/op/Max.h" namespace internal { @@ -33,6 +34,7 @@ struct NodeVisitor virtual void visit(const ResizeBilinear::Node &) = 0; virtual void visit(const FullyConnected::Node &) = 0; virtual void visit(const Softmax::Node &) = 0; + virtual void visit(const Max::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 4b0a83b..9942f73 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -252,6 +252,18 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, { switch (type) { + case ANEURALNETWORKS_TENSORFLOW_MAX_EX: + { + using internal::tflite::op::Max::Param; + using internal::tflite::op::Max::Node; + + // Add 'operations' + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } default: throw std::runtime_error{"Not supported operation"}; } -- 2.7.4