From 5d38732b23bc8076890286c04033236a8a2aea1b Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Tue, 5 Jun 2018 16:28:59 +0900 Subject: [PATCH] Support Cast operation for pureACL (#1569) This commit supports Cast operation for pureACL. - support adding Cast operation via API related to model. - The plan of Cast operation is not yet implemented. Signed-off-by: jiseob.jang --- runtimes/pure_arm_compute/src/compilation.cc | 6 +++ runtimes/pure_arm_compute/src/internal/op/Cast.cc | 46 +++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Cast.h | 53 ++++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 12 +++++ 5 files changed, 119 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Cast.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Cast.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 3037762..fdd6e8e 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -282,6 +282,7 @@ public: void visit(const ::internal::tflite::op::Softmax::Node &node) override; void visit(const ::internal::tflite::op::StridedSlice::Node &node) override; void visit(const ::internal::tflite::op::ReduceMax::Node &node) override; + void visit(const ::internal::tflite::op::Cast::Node &node) override; void visit(const ::internal::tflite::op::TopKV2::Node &node) override; void visit(const ::internal::tflite::op::Gather::Node &node) override; @@ -1323,6 +1324,11 @@ void Planner::visit(const ::internal::tflite::op::ReduceMax::Node &node) throw std::runtime_error{"ReduceMax: Not supported operation"}; } +void Planner::visit(const ::internal::tflite::op::Cast::Node &node) +{ + // TODO Implement the plan of Cast +} + void Planner::visit(const ::internal::tflite::op::TopKV2::Node &node) { const ::internal::tflite::operand::Index outputValues_index{node.param().outputValues_index}; diff --git a/runtimes/pure_arm_compute/src/internal/op/Cast.cc b/runtimes/pure_arm_compute/src/internal/op/Cast.cc new file mode 100644 index 0000000..4bcbbe0 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Cast.cc @@ -0,0 +1,46 @@ +#include "internal/op/Cast.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Cast +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace Cast +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Cast +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 1 && outputCount == 1); + + output_index = outputs[0]; + + // Each input should be interpreted as follows: + // 0 -> input Tensor Index + input_index = inputs[0]; +} + +} // namespace Cast +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Cast.h b/runtimes/pure_arm_compute/src/internal/op/Cast.h new file mode 100644 index 0000000..cc5c4e1 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Cast.h @@ -0,0 +1,53 @@ +#ifndef __INTERNAL_OP_CAST_H__ +#define __INTERNAL_OP_CAST_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace Cast +{ + +struct Param +{ + int32_t output_index; + + int32_t input_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace Cast +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_Cast_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index dc8a6a6..c44ed12 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -14,6 +14,7 @@ #include "internal/op/FullyConnected.h" #include "internal/op/Softmax.h" #include "internal/op/ReduceMax.h" +#include "internal/op/Cast.h" #include "internal/op/TopKV2.h" #include "internal/op/Gather.h" @@ -41,6 +42,7 @@ struct NodeVisitor virtual void visit(const FullyConnected::Node &) = 0; virtual void visit(const Softmax::Node &) = 0; virtual void visit(const ReduceMax::Node &) = 0; + virtual void visit(const Cast::Node &) = 0; virtual void visit(const TopKV2::Node &) = 0; virtual void visit(const Gather::Node &) = 0; }; diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 4cf6d3f..74f8107 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -268,6 +268,18 @@ int ANeuralNetworksModel_addOperationEx(ANeuralNetworksModel *model, { switch (type) { + case ANEURALNETWORKS_CAST_EX: + { + using internal::tflite::op::Cast::Param; + using internal::tflite::op::Cast::Node; + + // Add 'operations' + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_TENSORFLOW_MAX_EX: { using internal::tflite::op::ReduceMax::Param; -- 2.7.4