From: 김용섭/동작제어Lab(SR)/Engineer/삼성전자 Date: Fri, 7 Sep 2018 01:07:09 +0000 (+0900) Subject: Introduce EMBEDDING_LOOKUP op to pure_acl (#2631) X-Git-Tag: 0.2~59 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3d9cff194a245be06559d2321bfea017bd2cc426;p=platform%2Fcore%2Fml%2Fnnfw.git Introduce EMBEDDING_LOOKUP op to pure_acl (#2631) Introduce EMBEDDING_LOOKUP op to pure_acl Signed-off-by: Yongseop Kim --- diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 71c57d2..9f17014 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -481,6 +481,7 @@ public: void visit(const ::internal::tflite::op::Pad::Node &node) override; void visit(const ::internal::tflite::op::SpaceToDepth::Node &node) override; void visit(const ::internal::tflite::op::L2Pool2D::Implicit::Node &node) override; + void visit(const ::internal::tflite::op::EmbeddingLookup::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -3505,6 +3506,12 @@ void Planner::visit(const ::internal::tflite::op::L2Pool2D::Implicit::Node &node _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::EmbeddingLookup::Node &node) +{ + // TODO Implement EMBEDDING_LOOKUP + throw std::runtime_error("Not supported"); +} + class AllocationContext final : public IAllocationContext { public: diff --git a/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc new file mode 100644 index 0000000..2f658a1 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.cc @@ -0,0 +1,49 @@ +#include "internal/op/EmbeddingLookup.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace EmbeddingLookup +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace EmbeddingLookup +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace EmbeddingLookup +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 2 && outputCount == 1); + + output_index = outputs[0]; + + // Each input should be interpreted as follows: + // + // 0 -> Lookups Index + // 1 -> Values Index + lookups_index = inputs[0]; + values_index = inputs[1]; +} + +} // namespace EmbeddingLookup +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h new file mode 100644 index 0000000..2c9e2bb --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/EmbeddingLookup.h @@ -0,0 +1,54 @@ +#ifndef __INTERNAL_OP_EMBEDDING_LOOKUP_H__ +#define __INTERNAL_OP_EMBEDDING_LOOKUP_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace EmbeddingLookup +{ + +struct Param +{ + int32_t output_index; + + int32_t lookups_index; + int32_t values_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace EmbeddingLookup +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_EMBEDDING_LOOKUP_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 7c6d326..3c9ce33 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -36,6 +36,7 @@ #include "internal/op/Pad.h" #include "internal/op/SpaceToDepth.h" #include "internal/op/L2Pool2D.h" +#include "internal/op/EmbeddingLookup.h" namespace internal { @@ -87,6 +88,7 @@ struct NodeVisitor virtual void visit(const Pad::Node &) = 0; virtual void visit(const SpaceToDepth::Node &) = 0; virtual void visit(const L2Pool2D::Implicit::Node &) = 0; + virtual void visit(const EmbeddingLookup::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 43d770a..500b103 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -611,6 +611,20 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_EMBEDDING_LOOKUP: + { + assert(inputCount == 2); + assert(outputCount == 1); + + using internal::tflite::op::EmbeddingLookup::Param; + using internal::tflite::op::EmbeddingLookup::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } default: throw std::runtime_error{"Not supported operation"}; };