From e32301e63db0ba7c8362e535da36a7c35eb8ceb5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/=EB=8F=99=EC=9E=91=EC=A0=9C?= =?utf8?q?=EC=96=B4Lab=28SR=29/Engineer/=EC=82=BC=EC=84=B1=EC=A0=84?= =?utf8?q?=EC=9E=90?= Date: Wed, 22 Aug 2018 15:45:43 +0900 Subject: [PATCH] Introduce LSTM operation into pureACL (#2408) This commit introduces LSTM operation into pureACL. Signed-off-by: jiseob.jang --- runtimes/pure_arm_compute/src/compilation.cc | 7 ++ runtimes/pure_arm_compute/src/internal/op/Lstm.cc | 69 +++++++++++++++++++ runtimes/pure_arm_compute/src/internal/op/Lstm.h | 78 ++++++++++++++++++++++ .../pure_arm_compute/src/internal/op/NodeVisitor.h | 2 + runtimes/pure_arm_compute/src/model.cc | 11 +++ 5 files changed, 167 insertions(+) create mode 100644 runtimes/pure_arm_compute/src/internal/op/Lstm.cc create mode 100644 runtimes/pure_arm_compute/src/internal/op/Lstm.h diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 82a8648..e82a870 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -454,6 +454,7 @@ public: void visit(const ::internal::tflite::op::Logistic::Node &node) override; void visit(const ::internal::tflite::op::Mean::Node &node) override; void visit(const ::internal::tflite::op::RNN::Node &node) override; + void visit(const ::internal::tflite::op::LSTM::Node &node) override; void visit(const ::internal::tflite::op::Floor::Node &node) override; private: @@ -3182,6 +3183,12 @@ void Planner::visit(const ::internal::tflite::op::RNN::Node &node) _builder.addStage(stage); } +void Planner::visit(const ::internal::tflite::op::LSTM::Node &node) +{ + // TODO Implement LSTM op + throw std::runtime_error("Not supported, yet"); +} + void Planner::visit(const ::internal::tflite::op::Floor::Node &node) { VERBOSE(Floor) << "Configure Floor operation" << std::endl; diff --git a/runtimes/pure_arm_compute/src/internal/op/Lstm.cc 
b/runtimes/pure_arm_compute/src/internal/op/Lstm.cc new file mode 100644 index 0000000..85cb127 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Lstm.cc @@ -0,0 +1,69 @@ +#include "internal/op/Lstm.h" +#include "internal/op/NodeVisitor.h" + +#include <cassert> + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace LSTM +{ + +void Node::accept(NodeVisitor &&v) const { v.visit(*this); } + +} // namespace LSTM +} // namespace op +} // namespace tflite +} // namespace internal + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace LSTM +{ + +Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, + const uint32_t *outputs) +{ + assert(inputCount == 23 && outputCount == 4); + + scratch_buffer_index = outputs[0]; + output_state_out_index = outputs[1]; + cell_state_out_index = outputs[2]; + output_index = outputs[3]; + + input_index = inputs[0]; + input_to_input_weights_index = inputs[1]; + input_to_forget_weights_index = inputs[2]; + input_to_cell_weights_index = inputs[3]; + input_to_output_weights_index = inputs[4]; + recurrent_to_input_weights_index = inputs[5]; + recurrent_to_forget_weights_index = inputs[6]; + recurrent_to_cell_weights_index = inputs[7]; + recurrent_to_output_weights_index = inputs[8]; + cell_to_input_weights_index = inputs[9]; + cell_to_forget_weights_index = inputs[10]; + cell_to_output_weights_index = inputs[11]; + input_gate_bias_index = inputs[12]; + forget_gate_bias_index = inputs[13]; + cell_bias_index = inputs[14]; + output_gate_bias_index = inputs[15]; + projection_weights_index = inputs[16]; + projection_bias_index = inputs[17]; + output_state_in_index = inputs[18]; + cell_state_in_index = inputs[19]; + activation_index = inputs[20]; + cell_threshold_index = inputs[21]; + projection_threshold_index = inputs[22]; +} + +} // namespace LSTM +} // namespace op +} // namespace tflite +} // namespace internal diff --git a/runtimes/pure_arm_compute/src/internal/op/Lstm.h 
b/runtimes/pure_arm_compute/src/internal/op/Lstm.h new file mode 100644 index 0000000..a759ad9 --- /dev/null +++ b/runtimes/pure_arm_compute/src/internal/op/Lstm.h @@ -0,0 +1,78 @@ +#ifndef __INTERNAL_OP_LSTM_H__ +#define __INTERNAL_OP_LSTM_H__ + +#include "internal/op/Node.h" + +#include <cstdint> + +namespace internal +{ +namespace tflite +{ +namespace op +{ +namespace LSTM +{ + +struct Param +{ + int32_t scratch_buffer_index; + int32_t output_state_out_index; + int32_t cell_state_out_index; + int32_t output_index; + + int32_t input_index; + int32_t input_to_input_weights_index; + int32_t input_to_forget_weights_index; + int32_t input_to_cell_weights_index; + int32_t input_to_output_weights_index; + int32_t recurrent_to_input_weights_index; + int32_t recurrent_to_forget_weights_index; + int32_t recurrent_to_cell_weights_index; + int32_t recurrent_to_output_weights_index; + int32_t cell_to_input_weights_index; + int32_t cell_to_forget_weights_index; + int32_t cell_to_output_weights_index; + int32_t input_gate_bias_index; + int32_t forget_gate_bias_index; + int32_t cell_bias_index; + int32_t output_gate_bias_index; + int32_t projection_weights_index; + int32_t projection_bias_index; + int32_t output_state_in_index; + int32_t cell_state_in_index; + int32_t activation_index; + int32_t cell_threshold_index; + int32_t projection_threshold_index; + + Param() = default; + Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs); +}; + +class Node final : public op::Node +{ +public: + Node(const Param &param) : _param(param) + { + // DO NOTHING + } + +public: + virtual ~Node() = default; + +public: + const Param &param(void) const { return _param; } + +public: + void accept(NodeVisitor &&) const override; + +private: + const Param _param; +}; + +} // namespace LSTM +} // namespace op +} // namespace tflite +} // namespace internal + +#endif // __INTERNAL_OP_LSTM_H__ diff --git a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h 
b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h index 55ef997..9a523ef 100644 --- a/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h +++ b/runtimes/pure_arm_compute/src/internal/op/NodeVisitor.h @@ -28,6 +28,7 @@ #include "internal/op/Logistic.h" #include "internal/op/Mean.h" #include "internal/op/Rnn.h" +#include "internal/op/Lstm.h" #include "internal/op/Floor.h" namespace internal @@ -72,6 +73,7 @@ struct NodeVisitor virtual void visit(const Logistic::Node &) = 0; virtual void visit(const Mean::Node &) = 0; virtual void visit(const RNN::Node &) = 0; + virtual void visit(const LSTM::Node &) = 0; virtual void visit(const Floor::Node &) = 0; }; diff --git a/runtimes/pure_arm_compute/src/model.cc b/runtimes/pure_arm_compute/src/model.cc index e04b5d9..6355bfd 100644 --- a/runtimes/pure_arm_compute/src/model.cc +++ b/runtimes/pure_arm_compute/src/model.cc @@ -470,6 +470,17 @@ int ANeuralNetworksModel_addOperation(ANeuralNetworksModel *model, break; } + case ANEURALNETWORKS_LSTM: + { + using internal::tflite::op::LSTM::Param; + using internal::tflite::op::LSTM::Node; + + auto &operations = model->deref().operations(); + + operations.emplace_back<Node>(Param{inputCount, inputs, outputCount, outputs}); + + break; + } case ANEURALNETWORKS_FLOOR: { using internal::tflite::op::Floor::Param; -- 2.7.4