Implement RNN operation in pureacl (#2403)
author장지섭/동작제어Lab(SR)/Engineer/삼성전자 <jiseob.jang@samsung.com>
Wed, 22 Aug 2018 05:45:31 +0000 (14:45 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Wed, 22 Aug 2018 05:45:31 +0000 (14:45 +0900)
This commit implements RNN operation in pureacl.

Signed-off-by: jiseob.jang <jiseob.jang@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc
runtimes/pure_arm_compute/src/internal/arm_compute/Cast.h
runtimes/pure_arm_compute/src/internal/op/Rnn.cc

index 93e6b4e..82a8648 100644 (file)
@@ -21,7 +21,9 @@
 #include <arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h>
 #include <arm_compute/runtime/CL/functions/CLDequantizationLayer.h>
 #include <arm_compute/runtime/CL/functions/CLReductionMean.h>
+#include <arm_compute/runtime/CL/functions/CLRNNLayer.h>
 #include <arm_compute/runtime/CL/functions/CLFloor.h>
+#include <arm_compute/runtime/CL/functions/CLCopy.h>
 
 #include <arm_compute/runtime/SubTensor.h>
 #include <arm_compute/runtime/NEON/functions/NESoftmaxLayer.h>
@@ -3062,7 +3064,122 @@ void Planner::visit(const ::internal::tflite::op::Mean::Node &node)
 
 void Planner::visit(const ::internal::tflite::op::RNN::Node &node)
 {
-  // TODO Implement RNN op
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index hidden_state_out_index{
+      node.param().hidden_state_out_index};
+
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+  const ::internal::tflite::operand::Index weights_index{node.param().weights_index};
+  const ::internal::tflite::operand::Index recurrent_weights_index{
+      node.param().recurrent_weights_index};
+  const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+  const ::internal::tflite::operand::Index hidden_state_in_index{
+      node.param().hidden_state_in_index};
+  const ::internal::tflite::operand::Index fused_activation_index{
+      node.param().fused_activation_index};
+
+  assert(_ctx.at(output_index).shape().rank() == 2 &&
+         _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
+         _ctx.at(input_index).shape().rank() == 2 && _ctx.at(weights_index).shape().rank() == 2 &&
+         _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
+         _ctx.at(hidden_state_in_index).shape().rank() == 2);
+  assert(_ctx.at(bias_index).shape().rank() == 1);
+
+  const auto batch_size = _ctx.at(output_index).shape().dim(0);
+  assert(batch_size == _ctx.at(input_index).shape().dim(0) &&
+         batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
+         batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
+  assert(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
+
+  const auto num_units = _ctx.at(output_index).shape().dim(1);
+  assert(num_units == _ctx.at(weights_index).shape().dim(0) &&
+         num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
+         num_units == _ctx.at(bias_index).shape().dim(0));
+  assert(num_units == _ctx.at(output_index).shape().dim(1) &&
+         num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
+         num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
+         num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
+
+  // Set Shape Constraints and TensorInfo
+  _builder.addShapeConstr(
+      output_index, asTensorInfo(_ctx.at(output_index).shape(), _ctx.at(output_index).type()));
+  _builder.addShapeConstr(hidden_state_out_index,
+                          asTensorInfo(_ctx.at(hidden_state_out_index).shape(),
+                                       _ctx.at(hidden_state_out_index).type()));
+  _builder.addShapeConstr(input_index,
+                          asTensorInfo(_ctx.at(input_index).shape(), _ctx.at(input_index).type()));
+  _builder.addShapeConstr(
+      weights_index, asTensorInfo(_ctx.at(weights_index).shape(), _ctx.at(weights_index).type()));
+  _builder.addShapeConstr(recurrent_weights_index,
+                          asTensorInfo(_ctx.at(recurrent_weights_index).shape(),
+                                       _ctx.at(recurrent_weights_index).type()));
+  _builder.addShapeConstr(bias_index,
+                          asTensorInfo(_ctx.at(bias_index).shape(), _ctx.at(bias_index).type()));
+  _builder.addShapeConstr(
+      hidden_state_in_index,
+      asTensorInfo(_ctx.at(hidden_state_in_index).shape(), _ctx.at(hidden_state_in_index).type()));
+
+  // Construct operation parameters
+  struct Param
+  {
+    int output_index;
+    int hidden_state_out_index;
+
+    int input_index;
+    int weights_index;
+    int recurrent_weights_index;
+    int bias_index;
+    int hidden_state_in_index;
+
+    FuseCode activation;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.hidden_state_out_index = hidden_state_out_index.asInt();
+
+  param.input_index = input_index.asInt();
+  param.weights_index = weights_index.asInt();
+  param.recurrent_weights_index = recurrent_weights_index.asInt();
+  param.bias_index = bias_index.asInt();
+  param.hidden_state_in_index = hidden_state_in_index.asInt();
+  param.activation = static_cast<FuseCode>(_ctx.at(fused_activation_index).asScalar<int32_t>());
+
+  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
+    auto hidden_state_out_alloc =
+        ctx.at(::internal::tflite::operand::Index{param.hidden_state_out_index});
+    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
+    auto weights_alloc = ctx.at(::internal::tflite::operand::Index{param.weights_index});
+    auto recurrent_weights_alloc =
+        ctx.at(::internal::tflite::operand::Index{param.recurrent_weights_index});
+    auto bias_alloc = ctx.at(::internal::tflite::operand::Index{param.bias_index});
+    auto hidden_state_in_alloc =
+        ctx.at(::internal::tflite::operand::Index{param.hidden_state_in_index});
+    auto act_info = asActivationInfo(param.activation);
+
+    if (::internal::arm_compute::isGpuMode())
+    {
+      std::unique_ptr<::arm_compute::CLCopy> copy_fn{new ::arm_compute::CLCopy};
+      copy_fn->configure(CAST_CL(hidden_state_in_alloc), CAST_CL(hidden_state_out_alloc));
+      builder.append("COPY", std::move(copy_fn));
+
+      std::unique_ptr<::arm_compute::CLRNNLayer> rnn_fn{new ::arm_compute::CLRNNLayer};
+
+      // The hidden_state_in's data must be copied to hidden_state_out_alloc before fn->run() is
+      // performed.
+      rnn_fn->configure(CAST_CL(input_alloc), CAST_CL(weights_alloc),
+                        CAST_CL(recurrent_weights_alloc), CAST_CL(bias_alloc),
+                        CAST_CL(hidden_state_out_alloc), CAST_CL(output_alloc), act_info);
+
+      builder.append("RNN", std::move(rnn_fn));
+    }
+    else
+      throw std::runtime_error("Not supported, yet");
+  };
+
+  _builder.addStage(stage);
 }
 
 void Planner::visit(const ::internal::tflite::op::Floor::Node &node)
index e9f693b..9872456 100644 (file)
@@ -180,6 +180,27 @@ inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand:
   }
 }
 
+::arm_compute::ActivationLayerInfo asActivationInfo(FuseCode code)
+{
+  switch (code)
+  {
+    case ANEURALNETWORKS_FUSED_NONE:
+      return ::arm_compute::ActivationLayerInfo{};
+    case ANEURALNETWORKS_FUSED_RELU:
+      return ::arm_compute::ActivationLayerInfo{
+          ::arm_compute::ActivationLayerInfo::ActivationFunction::RELU};
+    case ANEURALNETWORKS_FUSED_RELU1:
+      return ::arm_compute::ActivationLayerInfo{
+          ::arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 1.0f, -1.0f};
+    case ANEURALNETWORKS_FUSED_RELU6:
+      return ::arm_compute::ActivationLayerInfo{
+          ::arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.0f, 0.0f};
+    default:
+      throw std::runtime_error("Not supported, yet");
+      break;
+  }
+}
+
 ::arm_compute::QuantizationInfo asQuantizationInfo(const float scale, const int32_t offset)
 {
   return ::arm_compute::QuantizationInfo(scale, offset);
index 54f483a..e0a2b29 100644 (file)
@@ -34,7 +34,7 @@ Param::Param(uint32_t inputCount, const uint32_t *inputs, uint32_t outputCount,
   assert(inputCount == 6 && outputCount == 2);
 
   output_index = outputs[0];
-  hidden_state_out_index = inputs[4];
+  hidden_state_out_index = outputs[1];
 
   input_index = inputs[0];
   weights_index = inputs[1];