From 3d9cff194a245be06559d2321bfea017bd2cc426 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EA=B9=80=EC=9A=A9=EC=84=AD/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Fri, 7 Sep 2018 10:07:09 +0900 Subject: [PATCH] Introduce EMBEDDING_LOOKUP op to pure_acl (#2631) Introduce EMBEDDING_LOOKUP op to pure_acl Signed-off-by: Yongseop Kim --- runtimes/pure_arm_compute/src/compilation.cc | 7 +++ .../src/internal/op/EmbeddingLookup.cc | 49 ++++++++++++++++++++ .../src/internal/op/EmbeddingLookup.h | 54 ++++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 14 ++++++ 5 files changed, 126 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 71c57d2..9f17014 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -481,6 +481,7 @@ public: void visit(const ::internal::tflite::op::Pad::Node &node) override; void visit(const ::internal::tflite::op::SpaceToDepth::Node &node) override; void visit(const ::internal::tflite::op::L2Pool2D::Implicit::Node &node) override; + void visit(const ::internal::tflite::op::EmbeddingLookup::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -3505,6 +3506,12 @@ void Planner::visit(const ::internal::tflite::op::L2Pool2D::Implicit::Node &node _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::EmbeddingLookup::Node &node) +{ + // TODO Implement EMBEDDING_LOOKUP + throw std::runtime_error("Not supported"); +} + class AllocationContext final : public IAllocationContext { public: diff --git a/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc new file mode 100644 index 0000000..2f658a1 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc @@ -0,0 +1,49 @@ +#include "internal/op/EmbeddingLookup.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace EmbeddingLookup +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace EmbeddingLookup +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace EmbeddingLookup +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 2 && outputCount == 1); + + output_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> Lookups Index + // 1 -> Values Index + lookups_index = inputs[0]; + values_index = inputs[1]; +} + +} // namespace EmbeddingLookup +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h new file mode 100644 index 0000000..2c9e2bb --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h @@ -0,0 +1,54 @@ +#ifndef __INTERNAL_OP_EMBEDDING_LOOKUP_H__ +#define __INTERNAL_OP_EMBEDDING_LOOKUP_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace EmbeddingLookup +{ + +struct Param +{ + int32_t output_index; + + int32_t lookups_index; + int32_t values_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace EmbeddingLookup +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_EMBEDDING_LOOKUP_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 7c6d326..3c9ce33 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -36,6 +36,7 @@ #include "internal/op/Pad.h" #include "internal/op/SpaceToDepth.h" #include "internal/op/L2Pool2D.h" +#include "internal/op/EmbeddingLookup.h" namespace internal { @@ -87,6 +88,7 @@ struct NodeVisitor virtual void visit(const Pad::Node &) = 0; virtual void visit(const SpaceToDepth::Node &) = 0; virtual void visit(const L2Pool2D::Implicit::Node &) = 0; + virtual void visit(const EmbeddingLookup::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 43d770a..500b103 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -611,6 +611,20 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_EMBEDDING_LOOKUP: + { + assert(inputCount == 2); + assert(outputCount == 1); + + using internal::tflite::op::EmbeddingLookup::Param; + using internal::tflite::op::EmbeddingLookup::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } default: throw std::runtime_error{"Not supported operation"}; }; -- 2.7.4