nnfw: RSQRT PACL Changes for RSQRT Support (#3040)
authorTANUJ TEKRIWAL/System SW /SRI-Bangalore/Engineer/삼성전자 <tanuj.tekri@samsung.com>
Tue, 16 Oct 2018 01:40:49 +0000 (07:10 +0530)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Tue, 16 Oct 2018 01:40:49 +0000 (10:40 +0900)
This patch adds support for RSQRT in PACL Codebase(compilation.cc)

Signed-off-by: Tanuj Tekriwal <tanuj.tekri@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index e7ae8e1..05a9cba 100644 (file)
@@ -30,6 +30,7 @@
 #include <arm_compute/runtime/CL/functions/CLPixelWiseDivision.h>
 #include <arm_compute/runtime/CL/functions/CLPoolingLayer.h>
 #include <arm_compute/runtime/CL/functions/CLActivationLayer.h>
+#include <arm_compute/runtime/CL/functions/CLActivationLayerEx.h>
 #include <arm_compute/runtime/CL/functions/CLScale.h>
 #include <arm_compute/runtime/CL/functions/CLReshapeLayer.h>
 #include <arm_compute/runtime/CL/functions/CLStridedSlice.h>
@@ -3369,7 +3370,46 @@ void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node)
 {
   VERBOSE(RSQRT) << "Configure Rsqrt operation" << std::endl;
 
-  throw std::runtime_error("Not supported, yet");
+  const ::internal::tflite::operand::Index output_index{node.param().output_index};
+  const ::internal::tflite::operand::Index input_index{node.param().input_index};
+
+  // Set shape constraints
+  _builder.addShapeConstr(output_index, asTensorInfo(asTensorShape(_ctx.at(output_index).shape()),
+                                                     _ctx.at(output_index).type()));
+  _builder.addShapeConstr(input_index, asTensorInfo(asTensorShape(_ctx.at(input_index).shape()),
+                                                    _ctx.at(input_index).type()));
+
+  struct Param
+  {
+    int output_index;
+    int input_index;
+  };
+
+  Param param;
+
+  param.output_index = output_index.asInt();
+  param.input_index = input_index.asInt();
+
+  auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+    auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index});
+    auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index});
+
+    const ::arm_compute::ActivationLayerInfoEx act_info{
+        ::arm_compute::ActivationLayerInfoEx::ActivationFunction::RSQRT};
+
+    if (::internal::arm_compute::isGpuMode())
+    {
+      auto fn = nnfw::make_unique<::arm_compute::CLActivationLayerEx>();
+
+      fn->configure(CAST_CL(input_alloc), CAST_CL(output_alloc), act_info);
+
+      builder.append("RSQRT", std::move(fn));
+    }
+    else
+      throw std::runtime_error("Not supported, yet");
+  };
+
+  _builder.addStage(stage);
 }
 
 void Planner::visit(const ::internal::tflite::op::Equal::Node &node)