This commit introduces LSTM operation into pureACL.
Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
void visit(const ::internal::tflite::op::Logistic::Node &node) override;
void visit(const ::internal::tflite::op::Mean::Node &node) override;
void visit(const ::internal::tflite::op::RNN::Node &node) override;
+ void visit(const ::internal::tflite::op::LSTM::Node &node) override;
void visit(const ::internal::tflite::op::Floor::Node &node) override;
private:
_builder.addStage(stage);
}
+void Planner::visit(const ::internal::tflite::op::LSTM::Node &node)
+{
+ // TODO Implement LSTM op
+ throw std::runtime_error("Not supported, yet");
+}
+
void Planner::visit(const ::internal::tflite::op::Floor::Node &node)
{
VERBOSE(Floor) << "Configure Floor operation" << std::endl;
--- /dev/null
+#include "internal/op/Lstm.h"
+#include "internal/op/NodeVisitor.h"
+
+#include <cassert>
+
+namespace internal
+{
+namespace tflite
+{
+namespace op
+{
+namespace LSTM
+{
+
+void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
+
+} // namespace LSTM
+} // namespace op
+} // namespace tflite
+} // namespace internal
+
+namespace internal
+{
+namespace tflite
+{
+namespace op
+{
+namespace LSTM
+{
+
+Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount,
+ const uint32_t *outputs)
+{
+ assert(inputCount == 23 && outputCount == 4);
+
+ scratch_buffer_index = outputs[0];
+ output_state_out_index = outputs[1];
+ cell_state_out_index = outputs[2];
+ output_index = outputs[3];
+
+ input_index = inputs[0];
+ input_to_input_weights_index = inputs[1];
+ input_to_forget_weights_index = inputs[2];
+ input_to_cell_weights_index = inputs[3];
+ input_to_output_weights_index = inputs[4];
+ recurrent_to_input_weights_index = inputs[5];
+ recurrent_to_forget_weights_index = inputs[6];
+ recurrent_to_cell_weights_index = inputs[7];
+ recurrent_to_output_weights_index = inputs[8];
+ cell_to_input_weights_index = inputs[9];
+ cell_to_forget_weights_index = inputs[10];
+ cell_to_output_weights_index = inputs[11];
+ input_gate_bias_index = inputs[12];
+ forget_gate_bias_index = inputs[13];
+ cell_bias_index = inputs[14];
+ output_gate_bias_index = inputs[15];
+ projection_weights_index = inputs[16];
+ projection_bias_index = inputs[17];
+ output_state_in_index = inputs[18];
+ cell_state_in_index = inputs[19];
+ activation_index = inputs[20];
+ cell_threshold_index = inputs[21];
+ projection_threshold_index = inputs[22];
+}
+
+} // namespace LSTM
+} // namespace op
+} // namespace tflite
+} // namespace internal
--- /dev/null
+#ifndef __INTERNAL_OP_LSTM_H__
+#define __INTERNAL_OP_LSTM_H__
+
+#include "internal/op/Node.h"
+
+#include <cstdint>
+
+namespace internal
+{
+namespace tflite
+{
+namespace op
+{
+namespace LSTM
+{
+
+struct Param
+{
+ int32_t scratch_buffer_index;
+ int32_t output_state_out_index;
+ int32_t cell_state_out_index;
+ int32_t output_index;
+
+ int32_t input_index;
+ int32_t input_to_input_weights_index;
+ int32_t input_to_forget_weights_index;
+ int32_t input_to_cell_weights_index;
+ int32_t input_to_output_weights_index;
+ int32_t recurrent_to_input_weights_index;
+ int32_t recurrent_to_forget_weights_index;
+ int32_t recurrent_to_cell_weights_index;
+ int32_t recurrent_to_output_weights_index;
+ int32_t cell_to_input_weights_index;
+ int32_t cell_to_forget_weights_index;
+ int32_t cell_to_output_weights_index;
+ int32_t input_gate_bias_index;
+ int32_t forget_gate_bias_index;
+ int32_t cell_bias_index;
+ int32_t output_gate_bias_index;
+ int32_t projection_weights_index;
+ int32_t projection_bias_index;
+ int32_t output_state_in_index;
+ int32_t cell_state_in_index;
+ int32_t activation_index;
+ int32_t cell_threshold_index;
+ int32_t projection_threshold_index;
+
+ Param() = default;
+ Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount, const uint32_t *outputs);
+};
+
+class Node final : public op::Node
+{
+public:
+ Node(const Param ¶m) : _param(param)
+ {
+ // DO NOTHING
+ }
+
+public:
+ virtual ~Node() = default;
+
+public:
+ const Param ¶m(void) const { return _param; }
+
+public:
+ void accept(NodeVisitor &&) const override;
+
+private:
+ const Param _param;
+};
+
+} // namespace LSTM
+} // namespace op
+} // namespace tflite
+} // namespace internal
+
+#endif // __INTERNAL_OP_LSTM_H__
#include "internal/op/Logistic.h"
#include "internal/op/Mean.h"
#include "internal/op/Rnn.h"
+#include "internal/op/Lstm.h"
#include "internal/op/Floor.h"
namespace internal
virtual void visit(const Logistic::Node &) = 0;
virtual void visit(const Mean::Node &) = 0;
virtual void visit(const RNN::Node &) = 0;
+ virtual void visit(const LSTM::Node &) = 0;
virtual void visit(const Floor::Node &) = 0;
};
break;
}
+ case ANEURALNETWORKS_LSTM:
+ {
+ using internal::tflite::op::LSTM::Param;
+ using internal::tflite::op::LSTM::Node;
+
+ auto &operations = model->deref().operations();
+
+ operations.emplace_back<Node>(Param{inputCount, inputs, outputCount, outputs});
+
+ break;
+ }
case ANEURALNETWORKS_FLOOR:
{
using internal::tflite::op::Floor::Param;