From cdd183bafc891d9526c55325ab03e11e19ac4298 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Siva=20Sai=20Vaddipati/System=20SW=20/SRI-Bangalore/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 11 Sep 2018 13:57:56 +0530 Subject: [PATCH] Introduce HashtableLookup op in PACL runtime (#2655) This commit introduces ANEURALNETWORKS_HASHTABLE_LOOKUP in PACL. Related issue: #2654 Signed-off-by: Siva Sai --- runtimes/pure_arm_compute/src/compilation.cc | 7 +++ .../src/internal/op/HashtableLookup.cc | 52 ++++++++++++++++++++ .../src/internal/op/HashtableLookup.h | 56 ++++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 14 ++++++ 5 files changed, 131 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/HashtableLookup.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/HashtableLookup.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index c806785..74ca80d 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -500,6 +500,7 @@ public: void visit(const ::internal::tflite::op::L2Pool2D::Implicit::Node &node) override; void visit(const ::internal::tflite::op::L2Pool2D::Explicit::Node &node) override; void visit(const ::internal::tflite::op::EmbeddingLookup::Node &node) override; + void visit(const ::internal::tflite::op::HashtableLookup::Node &node) override; private: const ::internal::tflite::operand::Set &_ctx; @@ -3697,6 +3698,12 @@ void Planner::visit(const ::internal::tflite::op::EmbeddingLookup::Node &node) _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::HashtableLookup::Node &node) +{ + // TODO Implement HashtableLookup + throw std::runtime_error("Not supported"); +} + class AllocationContext final : public IAllocationContext { public: diff --git a/runtimes/pure_arm_compute/src/internal/op/HashtableLookup.cc b/runtimes/pure_arm_compute/src/internal/op/HashtableLookup.cc new file mode 100644 index 0000000..30a853a --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/HashtableLookup.cc @@ -0,0 +1,52 @@ +#include "internal/op/HashtableLookup.h" +#include "internal/op/NodeVisitor.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace HashtableLookup +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace HashtableLookup +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace HashtableLookup +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 3 && outputCount == 2); + + output_index = outputs[0]; + hits_index = outputs[1]; + + // Each input should be interpreted as follows: + // + // 0 -> Lookups Index + // 1 -> Keys Index + // 2 -> Values Index + lookups_index = inputs[0]; + keys_index = inputs[1]; + values_index = inputs[2]; +} + +} // namespace HashtableLookup +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/HashtableLookup.h b/runtimes/pure_arm_compute/src/internal/op/HashtableLookup.h new file mode 100644 index 0000000..192da2a --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/HashtableLookup.h @@ -0,0 +1,56 @@ +#ifndef __INTERNAL_OP_HASHTABLE_LOOKUP_H__ +#define __INTERNAL_OP_HASHTABLE_LOOKUP_H__ + +#include "internal/op/Node.h" + +#include + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace HashtableLookup +{ + +struct Param +{ + int32_t output_index; + int32_t hits_index; + + int32_t lookups_index; + int32_t values_index; + int32_t keys_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param ¶m) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param ¶m(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace HashtableLookup +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_HASHTABLE_LOOKUP_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 34cf70c..96307c9 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -53,6 +53,7 @@ #include "internal/op/SpaceToDepth.h" #include "internal/op/L2Pool2D.h" #include "internal/op/EmbeddingLookup.h" +#include "internal/op/HashtableLookup.h" namespace internal { @@ -106,6 +107,7 @@ struct NodeVisitor virtual void visit(const L2Pool2D::Implicit::Node &) = 0; virtual void visit(const L2Pool2D::Explicit::Node &) = 0; virtual void visit(const EmbeddingLookup::Node &) = 0; + virtual void visit(const HashtableLookup::Node &) = 0; }; } // namespace op diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index 456f813..8c155ed 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -646,6 +646,20 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_HASHTABLE_LOOKUP: + { + assert(inputCount == 3); + assert(outputCount == 2); + + using internal::tflite::op::HashtableLookup::Param; + using internal::tflite::op::HashtableLookup::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back(Param{inputCount, inputs, outputCount, outputs}); + + break; + } default: throw std::runtime_error{"Not supported operation"}; }; -- 2.7.4