From f9b019656f891a69fc9f59d3ec708d7efc254438 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9E=A5=EC=A7=80=EC=84=AD/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 5 Sep 2019 14:22:52 +0900 Subject: [PATCH] Enable LSTM op for ACL neon (#7201) This commit enables to support LSTM op for ACL neon. Signed-off-by: jiseob.jang --- .../neurun/backend/acl_neon/ConstantInitializer.cc | 83 +++++++++++ .../neurun/backend/acl_neon/ConstantInitializer.h | 1 + .../neurun/backend/acl_neon/KernelGenerator.cc | 162 +++++++++++++++++++++ runtimes/neurun/backend/acl_neon/KernelGenerator.h | 1 + runtimes/neurun/backend/acl_neon/ShapeFixer.cc | 2 + runtimes/neurun/backend/acl_neon/ShapeFixer.h | 1 + tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon | 3 +- 7 files changed, 252 insertions(+), 1 deletion(-) diff --git a/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc b/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc index ed8b31b..d83b7c0 100644 --- a/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc +++ b/runtimes/neurun/backend/acl_neon/ConstantInitializer.cc @@ -78,6 +78,89 @@ void ConstantInitializer::visit(const model::operation::FullyConnectedNode &node registerCopyInitializer(bias_index, bias_obj); } +void ConstantInitializer::visit(const model::operation::LSTMNode &node) +{ + const auto &input_to_input_weights_index = + node.getInputs().at(model::operation::LSTMNode::INPUT_TO_INPUT_WEIGHTS); + const auto &input_to_input_weights_obj = _operands.at(input_to_input_weights_index); + registerCopyInitializer(input_to_input_weights_index, input_to_input_weights_obj); + + const auto &input_to_forget_weights_index = + node.getInputs().at(model::operation::LSTMNode::INPUT_TO_FORGET_WEIGHTS); + const auto &input_to_forget_weights_obj = _operands.at(input_to_forget_weights_index); + registerCopyInitializer(input_to_forget_weights_index, input_to_forget_weights_obj); + + const auto &input_to_cell_weights_index = + node.getInputs().at(model::operation::LSTMNode::INPUT_TO_CELL_WEIGHTS); + const auto &input_to_cell_weights_obj = _operands.at(input_to_cell_weights_index); + registerCopyInitializer(input_to_cell_weights_index, input_to_cell_weights_obj); + + const auto &input_to_output_weights_index = + node.getInputs().at(model::operation::LSTMNode::INPUT_TO_OUTPUT_WEIGHTS); + const auto &input_to_output_weights_obj = _operands.at(input_to_output_weights_index); + registerCopyInitializer(input_to_output_weights_index, input_to_output_weights_obj); + + const auto &recurrent_to_input_weights_index = + node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_INPUT_WEIGHTS); + const auto &recurrent_to_input_weights_obj = _operands.at(recurrent_to_input_weights_index); + registerCopyInitializer(recurrent_to_input_weights_index, recurrent_to_input_weights_obj); + + const auto &recurrent_to_forget_weights_index = + node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_FORGET_WEIGHTS); + const auto &recurrent_to_forget_weights_obj = _operands.at(recurrent_to_forget_weights_index); + registerCopyInitializer(recurrent_to_forget_weights_index, recurrent_to_forget_weights_obj); + + const auto &recurrent_to_cell_weights_index = + node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_CELL_WEIGHTS); + const auto &recurrent_to_cell_weights_obj = _operands.at(recurrent_to_cell_weights_index); + registerCopyInitializer(recurrent_to_cell_weights_index, recurrent_to_cell_weights_obj); + + const auto &recurrent_to_output_weights_index = + node.getInputs().at(model::operation::LSTMNode::RECURRENT_TO_OUTPUT_WEIGHTS); + const auto &recurrent_to_output_weights_obj = _operands.at(recurrent_to_output_weights_index); + registerCopyInitializer(recurrent_to_output_weights_index, recurrent_to_output_weights_obj); + + const auto &cell_to_input_weights_index = + node.getInputs().at(model::operation::LSTMNode::CELL_TO_INPUT_WEIGHTS); + const auto &cell_to_input_weights_obj = _operands.at(cell_to_input_weights_index); + registerCopyInitializer(cell_to_input_weights_index, cell_to_input_weights_obj); + + const auto &cell_to_forget_weights_index = + node.getInputs().at(model::operation::LSTMNode::CELL_TO_FORGET_WEIGHTS); + const auto &cell_to_forget_weights_obj = _operands.at(cell_to_forget_weights_index); + registerCopyInitializer(cell_to_forget_weights_index, cell_to_forget_weights_obj); + + const auto &cell_to_output_weights_index = + node.getInputs().at(model::operation::LSTMNode::CELL_TO_OUTPUT_WEIGHTS); + const auto &cell_to_output_weights_obj = _operands.at(cell_to_output_weights_index); + registerCopyInitializer(cell_to_output_weights_index, cell_to_output_weights_obj); + + const auto &input_gate_bias_index = + node.getInputs().at(model::operation::LSTMNode::INPUT_GATE_BIAS); + const auto &input_gate_bias_obj = _operands.at(input_gate_bias_index); + registerCopyInitializer(input_gate_bias_index, input_gate_bias_obj); + + const auto &forget_gate_bias_index = + node.getInputs().at(model::operation::LSTMNode::FORGET_GATE_BIAS); + const auto &forget_gate_bias_obj = _operands.at(forget_gate_bias_index); + registerCopyInitializer(forget_gate_bias_index, forget_gate_bias_obj); + + const auto &output_gate_bias_index = + node.getInputs().at(model::operation::LSTMNode::OUTPUT_GATE_BIAS); + const auto &output_gate_bias_obj = _operands.at(output_gate_bias_index); + registerCopyInitializer(output_gate_bias_index, output_gate_bias_obj); + + const auto &projection_weights_index = + node.getInputs().at(model::operation::LSTMNode::PROJECTION_WEIGHTS); + const auto &projection_weights_obj = _operands.at(projection_weights_index); + registerCopyInitializer(projection_weights_index, projection_weights_obj); + + const auto &projection_bias_index = + node.getInputs().at(model::operation::LSTMNode::PROJECTION_BIAS); + const auto &projection_bias_obj = _operands.at(projection_bias_index); + registerCopyInitializer(projection_bias_index, projection_bias_obj); +} + void ConstantInitializer::visit(const model::operation::RNNNode &node) { const auto &weights_index = node.getInputs().at(model::operation::RNNNode::WEIGHTS); diff --git a/runtimes/neurun/backend/acl_neon/ConstantInitializer.h b/runtimes/neurun/backend/acl_neon/ConstantInitializer.h index 448fe13..cdd94f7 100644 --- a/runtimes/neurun/backend/acl_neon/ConstantInitializer.h +++ b/runtimes/neurun/backend/acl_neon/ConstantInitializer.h @@ -41,6 +41,7 @@ public: void visit(const model::operation::Conv2DNode &) override; void visit(const model::operation::DepthwiseConv2DNode &) override; void visit(const model::operation::FullyConnectedNode &) override; + void visit(const model::operation::LSTMNode &) override; void visit(const model::operation::RNNNode &) override; void visit(const model::operation::TransposeConvNode &) override; diff --git a/runtimes/neurun/backend/acl_neon/KernelGenerator.cc b/runtimes/neurun/backend/acl_neon/KernelGenerator.cc index f2e69c1..8587eb4 100644 --- a/runtimes/neurun/backend/acl_neon/KernelGenerator.cc +++ b/runtimes/neurun/backend/acl_neon/KernelGenerator.cc @@ -731,6 +731,168 @@ void KernelGenerator::visit(const model::operation::LogisticNode &node) _execution_builder->append(std::move(acl_fn)); } +void KernelGenerator::visit(const model::operation::LSTMNode &node) +{ + // TODO Support dynamic rnn + // TODO Fix subtle error in the case of non-CIFG, non-peephole and No Projection. + const auto scratch_buffer_index{ + node.getOutputs().at(model::operation::LSTMNode::Output::SCRATCH_BUFFER)}; + const auto output_state_out_index{ + node.getOutputs().at(model::operation::LSTMNode::Output::OUTPUT_STATE_OUT)}; + const auto cell_state_out_index{ + node.getOutputs().at(model::operation::LSTMNode::Output::CELL_STATE_OUT)}; + const auto output_index{node.getOutputs().at(model::operation::LSTMNode::Output::OUTPUT)}; + + const auto input_index{node.getInputs().at(model::operation::LSTMNode::Input::INPUT)}; + const auto input_to_input_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::INPUT_TO_INPUT_WEIGHTS)}; // optional + const auto input_to_forget_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::INPUT_TO_FORGET_WEIGHTS)}; + const auto input_to_cell_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::INPUT_TO_CELL_WEIGHTS)}; + const auto input_to_output_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::INPUT_TO_OUTPUT_WEIGHTS)}; + const auto recurrent_to_input_weights_index{node.getInputs().at( + model::operation::LSTMNode::Input::RECURRENT_TO_INPUT_WEIGHTS)}; // optional + const auto recurrent_to_forget_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::RECURRENT_TO_FORGET_WEIGHTS)}; + const auto recurrent_to_cell_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::RECURRENT_TO_CELL_WEIGHTS)}; + const auto recurrent_to_output_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::RECURRENT_TO_OUTPUT_WEIGHTS)}; + const auto cell_to_input_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::CELL_TO_INPUT_WEIGHTS)}; // optional + const auto cell_to_forget_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::CELL_TO_FORGET_WEIGHTS)}; // optional + const auto cell_to_output_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::CELL_TO_OUTPUT_WEIGHTS)}; // optional + const auto input_gate_bias_index{ + node.getInputs().at(model::operation::LSTMNode::Input::INPUT_GATE_BIAS)}; + const auto forget_gate_bias_index{ + node.getInputs().at(model::operation::LSTMNode::Input::FORGET_GATE_BIAS)}; + const auto cell_bias_index{node.getInputs().at(model::operation::LSTMNode::Input::CELL_BIAS)}; + const auto output_gate_bias_index{ + node.getInputs().at(model::operation::LSTMNode::Input::OUTPUT_GATE_BIAS)}; + const auto projection_weights_index{ + node.getInputs().at(model::operation::LSTMNode::Input::PROJECTION_WEIGHTS)}; // optional + const auto projection_bias_index{ + node.getInputs().at(model::operation::LSTMNode::Input::PROJECTION_BIAS)}; // optional + const auto output_state_in_index{ + node.getInputs().at(model::operation::LSTMNode::Input::OUTPUT_STATE_IN)}; + const auto cell_state_in_index{ + node.getInputs().at(model::operation::LSTMNode::Input::CELL_STATE_IN)}; + const auto cell_threshold = node.param().cell_threshold; + const auto projection_threshold = node.param().projection_threshold; + + bool has_input_to_input_weights = _ctx.at(input_to_input_weights_index).shape().dim(0) != 0 && + _ctx.at(input_to_input_weights_index).shape().dim(1) != 0; + bool has_recurrent_to_input_weights = + _ctx.at(recurrent_to_input_weights_index).shape().dim(0) != 0 && + _ctx.at(recurrent_to_input_weights_index).shape().dim(1) != 0; + bool has_cell_to_forget_weights = _ctx.at(cell_to_forget_weights_index).shape().dim(0) != 0; + bool has_cell_to_output_weights = _ctx.at(cell_to_output_weights_index).shape().dim(0) != 0; + bool has_projection_weights = _ctx.at(projection_weights_index).shape().dim(0) != 0 && + _ctx.at(projection_weights_index).shape().dim(1) != 0; + bool has_projection_bias = _ctx.at(projection_bias_index).shape().dim(0); + + // NOTE The input_to_input_weights and the recurrent_to_input_weights do not exist in CIFG. + // true: no CIFG + // false: CIFG + // NOTE The cell_to_input_weights does not exist in non-peephole although regular LSTM(non-CIFG). + bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights; + + // NOTE The cell_to_forget_weights and the cell_to_output_weights exist in peephole. + // But the cell_to_input_weights does not exist in regular CIFG although peephole. + // true: peephole + // false: no peephole + bool has_peephole_param = has_cell_to_forget_weights && has_cell_to_output_weights; + + // NOTE Although the projection weights has data the projection bias may not have data. + bool has_projection_param = has_projection_weights; + + const auto activation = node.param().activation; + const auto cell_clip = cell_threshold; + const auto projection_clip = projection_threshold; + assert(cell_clip >= 0.f && projection_clip >= 0.f); + + auto scratch_buffer_alloc = _tensor_builder->at(scratch_buffer_index).get(); + auto output_state_out_alloc = _tensor_builder->at(output_state_out_index).get(); + auto cell_state_out_alloc = _tensor_builder->at(cell_state_out_index).get(); + auto output_alloc = _tensor_builder->at(output_index).get(); + + auto input_alloc = _tensor_builder->at(input_index).get(); + + auto input_to_forget_weights_alloc = _tensor_builder->at(input_to_forget_weights_index).get(); + auto input_to_cell_weights_alloc = _tensor_builder->at(input_to_cell_weights_index).get(); + auto input_to_output_weights_alloc = _tensor_builder->at(input_to_output_weights_index).get(); + auto recurrent_to_forget_weights_alloc = + _tensor_builder->at(recurrent_to_forget_weights_index).get(); + auto recurrent_to_cell_weights_alloc = _tensor_builder->at(recurrent_to_cell_weights_index).get(); + auto recurrent_to_output_weights_alloc = + _tensor_builder->at(recurrent_to_output_weights_index).get(); + + auto forget_gate_bias_alloc = _tensor_builder->at(forget_gate_bias_index).get(); + auto cell_bias_alloc = _tensor_builder->at(cell_bias_index).get(); + auto output_gate_bias_alloc = _tensor_builder->at(output_gate_bias_index).get(); + auto output_state_in_alloc = _tensor_builder->at(output_state_in_index).get(); + auto cell_state_in_alloc = _tensor_builder->at(cell_state_in_index).get(); + + auto act_info = ::neurun::backend::acl_common::asActivationLayerInfo(activation); + + std::unique_ptr<::arm_compute::IFunction> fn; + + auto l = nnfw::cpp14::make_unique<::arm_compute::NELSTMLayer>(); + + ::arm_compute::LSTMParams<::arm_compute::ITensor> lstm_params{}; + if (has_cifg_param) + { + auto input_to_input_weights_alloc = + _tensor_builder->at(input_to_input_weights_index).get(); // optional + auto recurrent_to_input_weights_alloc = + _tensor_builder->at(recurrent_to_input_weights_index).get(); // optional + auto cell_to_input_weights_handle = + has_peephole_param ? _tensor_builder->at(cell_to_input_weights_index).get()->handle() + : nullptr; // optional (non-cifg && peephole) + auto input_gate_bias_alloc = _tensor_builder->at(input_gate_bias_index).get(); // optional + lstm_params.set_cifg_params(input_to_input_weights_alloc->handle(), + recurrent_to_input_weights_alloc->handle(), + cell_to_input_weights_handle, input_gate_bias_alloc->handle()); + } + if (has_peephole_param) + { + auto cell_to_forget_weights_alloc = + _tensor_builder->at(cell_to_forget_weights_index).get(); // optional + auto cell_to_output_weights_alloc = + _tensor_builder->at(cell_to_output_weights_index).get(); // optional + lstm_params.set_peephole_params(cell_to_forget_weights_alloc->handle(), + cell_to_output_weights_alloc->handle()); + } + if (has_projection_param) + { + auto projection_weights_alloc = _tensor_builder->at(projection_weights_index).get(); // optional + auto projection_bias_handle = has_projection_bias + ? _tensor_builder->at(projection_bias_index).get()->handle() + : nullptr; // optional + lstm_params.set_projection_params(projection_weights_alloc->handle(), projection_bias_handle); + } + + l->configure( + input_alloc->handle(), input_to_forget_weights_alloc->handle(), + input_to_cell_weights_alloc->handle(), input_to_output_weights_alloc->handle(), + recurrent_to_forget_weights_alloc->handle(), recurrent_to_cell_weights_alloc->handle(), + recurrent_to_output_weights_alloc->handle(), forget_gate_bias_alloc->handle(), + cell_bias_alloc->handle(), output_gate_bias_alloc->handle(), output_state_in_alloc->handle(), + cell_state_in_alloc->handle(), scratch_buffer_alloc->handle(), + output_state_out_alloc->handle(), cell_state_out_alloc->handle(), output_alloc->handle(), + lstm_params, act_info, cell_clip, projection_clip); + + fn = std::move(l); + + auto acl_fn = asAclFunction(std::move(fn)); + + _execution_builder->append(std::move(acl_fn)); +} + void KernelGenerator::visit(const model::operation::MulNode &node) { const auto ofm_index{node.getOutputs().at(0)}; diff --git a/runtimes/neurun/backend/acl_neon/KernelGenerator.h b/runtimes/neurun/backend/acl_neon/KernelGenerator.h index 7d4fae5..5554d0a 100644 --- a/runtimes/neurun/backend/acl_neon/KernelGenerator.h +++ b/runtimes/neurun/backend/acl_neon/KernelGenerator.h @@ -50,6 +50,7 @@ public: void visit(const model::operation::LocalResponseNormalizationNode &) override; void visit(const model::operation::LogicalNotNode &) override; void visit(const model::operation::LogisticNode &) override; + void visit(const model::operation::LSTMNode &) override; void visit(const model::operation::MulNode &) override; void visit(const model::operation::ReLUNode &) override; void visit(const model::operation::ReLU1Node &) override; diff --git a/runtimes/neurun/backend/acl_neon/ShapeFixer.cc b/runtimes/neurun/backend/acl_neon/ShapeFixer.cc index 30e8eaa..6c23123 100644 --- a/runtimes/neurun/backend/acl_neon/ShapeFixer.cc +++ b/runtimes/neurun/backend/acl_neon/ShapeFixer.cc @@ -106,6 +106,8 @@ void ShapeFixer::visit(const model::operation::LogicalNotNode &) { /* DO NOTHING void ShapeFixer::visit(const model::operation::LogisticNode &) { /* DO NOTHING */} +void ShapeFixer::visit(const model::operation::LSTMNode &) { /* DO NOTHING */} + void ShapeFixer::visit(const model::operation::MulNode &node) { const auto lhs_index{node.getInputs().at(model::operation::MulNode::Input::LHS)}; diff --git a/runtimes/neurun/backend/acl_neon/ShapeFixer.h b/runtimes/neurun/backend/acl_neon/ShapeFixer.h index a099ed5..b37478c 100644 --- a/runtimes/neurun/backend/acl_neon/ShapeFixer.h +++ b/runtimes/neurun/backend/acl_neon/ShapeFixer.h @@ -52,6 +52,7 @@ public: void visit(const model::operation::LocalResponseNormalizationNode &) override; void visit(const model::operation::LogicalNotNode &) override; void visit(const model::operation::LogisticNode &) override; + void visit(const model::operation::LSTMNode &) override; void visit(const model::operation::MulNode &) override; void visit(const model::operation::ReLUNode &) override; void visit(const model::operation::ReLU1Node &) override; diff --git a/tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon b/tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon index ce85a72..758a609 100644 --- a/tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon +++ b/tests/nnapi/nnapi_gtest.skip.armv7l-linux.acl_neon @@ -14,7 +14,6 @@ GeneratedTests.hashtable_lookup* GeneratedTests.logical_and_ex* GeneratedTests.logical_or_ex* GeneratedTests.lsh_projection* -GeneratedTests.lstm* GeneratedTests.mobilenet* GeneratedTests.neg* GeneratedTests.notequal* @@ -38,3 +37,5 @@ GeneratedTests.pack* # Float error GeneratedTests.exp_ex_1D_float GeneratedTests.exp_ex_2D_float +# Unsupported optional input that has shape +GeneratedTests.lstm2* -- 2.7.4