#include <arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h>
#include <arm_compute/runtime/CL/functions/CLDequantizationLayer.h>
#include <arm_compute/runtime/CL/functions/CLReductionMean.h>
+#include <arm_compute/runtime/CL/functions/CLRNNLayer.h>
#include <arm_compute/runtime/CL/functions/CLFloor.h>
+#include <arm_compute/runtime/CL/functions/CLCopy.h>
#include <arm_compute/runtime/SubTensor.h>
#include <arm_compute/runtime/NEON/functions/NESoftmaxLayer.h>
// Plans the RNN operation described by `node`.
//
// Work done here, in order:
//   1. Reads the operand indices (two outputs, five tensor inputs, and the fused-activation
//      scalar) from node.param().
//   2. Sanity-checks operand ranks and dimensions with assert().
//   3. Registers a shape constraint for each of the seven tensor operands.
//   4. Registers a stage that, when invoked, appends a "COPY" function followed by an "RNN"
//      function to the execution builder (GPU mode only).
//
// Throws std::runtime_error("Not supported, yet") from inside the stage when
// ::internal::arm_compute::isGpuMode() is false.
void Planner::visit(const ::internal::tflite::op::RNN::Node &node)
{
- // TODO Implement RNN op
// Output operands.
+ const ::internal::tflite::operand::Index output_index{node.param().output_index};
+ const ::internal::tflite::operand::Index hidden_state_out_index{
+ node.param().hidden_state_out_index};
+
// Input operands.
+ const ::internal::tflite::operand::Index input_index{node.param().input_index};
+ const ::internal::tflite::operand::Index weights_index{node.param().weights_index};
+ const ::internal::tflite::operand::Index recurrent_weights_index{
+ node.param().recurrent_weights_index};
+ const ::internal::tflite::operand::Index bias_index{node.param().bias_index};
+ const ::internal::tflite::operand::Index hidden_state_in_index{
+ node.param().hidden_state_in_index};
+ const ::internal::tflite::operand::Index fused_activation_index{
+ node.param().fused_activation_index};
+
// Every tensor operand must be rank 2, except the bias which must be rank 1.
// NOTE(review): all shape checks in this function are assert()-only, so they are compiled out
// when NDEBUG is defined.
+ assert(_ctx.at(output_index).shape().rank() == 2 &&
+ _ctx.at(hidden_state_out_index).shape().rank() == 2 &&
+ _ctx.at(input_index).shape().rank() == 2 && _ctx.at(weights_index).shape().rank() == 2 &&
+ _ctx.at(recurrent_weights_index).shape().rank() == 2 &&
+ _ctx.at(hidden_state_in_index).shape().rank() == 2);
+ assert(_ctx.at(bias_index).shape().rank() == 1);
+
// Dimension 0 of output, input, hidden_state_in and hidden_state_out must all agree.
+ const auto batch_size = _ctx.at(output_index).shape().dim(0);
+ assert(batch_size == _ctx.at(input_index).shape().dim(0) &&
+ batch_size == _ctx.at(hidden_state_in_index).shape().dim(0) &&
+ batch_size == _ctx.at(hidden_state_out_index).shape().dim(0));
// Dimension 1 of input must match dimension 1 of weights.
+ assert(_ctx.at(input_index).shape().dim(1) == _ctx.at(weights_index).shape().dim(1));
+
// num_units is taken from dimension 1 of output and must equal:
//   - dimension 0 of weights, recurrent_weights and bias
//   - dimension 1 of recurrent_weights, hidden_state_in and hidden_state_out
+ const auto num_units = _ctx.at(output_index).shape().dim(1);
+ assert(num_units == _ctx.at(weights_index).shape().dim(0) &&
+ num_units == _ctx.at(recurrent_weights_index).shape().dim(0) &&
+ num_units == _ctx.at(bias_index).shape().dim(0));
// NOTE(review): the first comparison below is trivially true, because num_units was
// initialized from this very expression a few lines above.
+ assert(num_units == _ctx.at(output_index).shape().dim(1) &&
+ num_units == _ctx.at(recurrent_weights_index).shape().dim(1) &&
+ num_units == _ctx.at(hidden_state_in_index).shape().dim(1) &&
+ num_units == _ctx.at(hidden_state_out_index).shape().dim(1));
+
+ // Set Shape Constraints and TensorInfo
// One constraint per tensor operand, built from that operand's own shape and type.
// The fused-activation operand is read as a scalar below and gets no constraint.
+ _builder.addShapeConstr(
+ output_index, asTensorInfo(_ctx.at(output_index).shape(), _ctx.at(output_index).type()));
+ _builder.addShapeConstr(hidden_state_out_index,
+ asTensorInfo(_ctx.at(hidden_state_out_index).shape(),
+ _ctx.at(hidden_state_out_index).type()));
+ _builder.addShapeConstr(input_index,
+ asTensorInfo(_ctx.at(input_index).shape(), _ctx.at(input_index).type()));
+ _builder.addShapeConstr(
+ weights_index, asTensorInfo(_ctx.at(weights_index).shape(), _ctx.at(weights_index).type()));
+ _builder.addShapeConstr(recurrent_weights_index,
+ asTensorInfo(_ctx.at(recurrent_weights_index).shape(),
+ _ctx.at(recurrent_weights_index).type()));
+ _builder.addShapeConstr(bias_index,
+ asTensorInfo(_ctx.at(bias_index).shape(), _ctx.at(bias_index).type()));
+ _builder.addShapeConstr(
+ hidden_state_in_index,
+ asTensorInfo(_ctx.at(hidden_state_in_index).shape(), _ctx.at(hidden_state_in_index).type()));
+
+ // Construct operation parameters
// Plain-value snapshot of everything the stage needs. The lambda below captures it by copy,
// so the stage does not refer back to `node` or to the local Index objects.
+ struct Param
+ {
+ int output_index;
+ int hidden_state_out_index;
+
+ int input_index;
+ int weights_index;
+ int recurrent_weights_index;
+ int bias_index;
+ int hidden_state_in_index;
+
+ FuseCode activation;
+ };
+
+ Param param;
+
+ param.output_index = output_index.asInt();
+ param.hidden_state_out_index = hidden_state_out_index.asInt();
+
+ param.input_index = input_index.asInt();
+ param.weights_index = weights_index.asInt();
+ param.recurrent_weights_index = recurrent_weights_index.asInt();
+ param.bias_index = bias_index.asInt();
+ param.hidden_state_in_index = hidden_state_in_index.asInt();
// The activation operand is read as an int32 scalar and cast to FuseCode.
+ param.activation = static_cast<FuseCode>(_ctx.at(fused_activation_index).asScalar<int32_t>());
+
// Stage: looks up the allocation for every operand and appends the functions to `builder`.
+ auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+ auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
+ auto hidden_state_out_alloc =
+ ctx.at(::internal::tflite::operand::Index{param.hidden_state_out_index});
+ auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
+ auto weights_alloc = ctx.at(::internal::tflite::operand::Index{param.weights_index});
+ auto recurrent_weights_alloc =
+ ctx.at(::internal::tflite::operand::Index{param.recurrent_weights_index});
+ auto bias_alloc = ctx.at(::internal::tflite::operand::Index{param.bias_index});
+ auto hidden_state_in_alloc =
+ ctx.at(::internal::tflite::operand::Index{param.hidden_state_in_index});
+ auto act_info = asActivationInfo(param.activation);
+
+ if (::internal::arm_compute::isGpuMode())
+ {
// Appended first: copies hidden_state_in into hidden_state_out. The RNN function is
// configured on hidden_state_out below and never receives hidden_state_in directly.
+ std::unique_ptr<::arm_compute::CLCopy> copy_fn{new ::arm_compute::CLCopy};
+ copy_fn->configure(CAST_CL(hidden_state_in_alloc), CAST_CL(hidden_state_out_alloc));
+ builder.append("COPY", std::move(copy_fn));
+
+ std::unique_ptr<::arm_compute::CLRNNLayer> rnn_fn{new ::arm_compute::CLRNNLayer};
+
+ // The hidden_state_in's data must be copied to hidden_state_out_alloc before fn->run() is
+ // performed.
// NOTE(review): presumably CLRNNLayer uses the hidden-state tensor as both the incoming and
// the outgoing state, which is why hidden_state_out is passed here - confirm against the
// ARM Compute Library documentation.
+ rnn_fn->configure(CAST_CL(input_alloc), CAST_CL(weights_alloc),
+ CAST_CL(recurrent_weights_alloc), CAST_CL(bias_alloc),
+ CAST_CL(hidden_state_out_alloc), CAST_CL(output_alloc), act_info);
+
+ builder.append("RNN", std::move(rnn_fn));
+ }
+ else
// Only the GPU (CL) path is implemented; any other mode fails when the stage runs.
+ throw std::runtime_error("Not supported, yet");
+ };
+
+ _builder.addStage(stage);
}
void Planner::visit(const ::internal::tflite::op::Floor::Node &node)