From 95857aaf55fab60878793e567e8f914224e765c7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?TANUJ=20TEKRIWAL/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 16 Oct 2018 07:10:49 +0530 Subject: [PATCH] nnfw: RSQRT PACL Changes for RSQRT Support (#3040) This patch adds support for RSQRT in PACL Codebase(compilation.cc) Signed-off-by: Tanuj Tekriwal --- runtimes/pure_arm_compute/src/compilation.cc | 42 +++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index e7ae8e1..05a9cba 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -3369,7 +3370,46 @@ void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node) { VERBOSE(RSQRT) << "Configure Rsqrt operation" << std::endl; - throw std::runtime_error("Not supported, yet"); + const ::internal::tflite::operand::Index output_index{node.param().output_index}; + const ::internal::tflite::operand::Index input_index{node.param().input_index}; + + // Set shape constraints + _builder.addShapeConstr(output_index, asTensorInfo(asTensorShape(_ctx.at(output_index).shape()), + _ctx.at(output_index).type())); + _builder.addShapeConstr(input_index, asTensorInfo(asTensorShape(_ctx.at(input_index).shape()), + _ctx.at(input_index).type())); + + struct Param + { + int output_index; + int input_index; + }; + + Param param; + + param.output_index = output_index.asInt(); + param.input_index = input_index.asInt(); + + auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) { + auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index}); + auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index}); + + const ::arm_compute::ActivationLayerInfoEx act_info{ + ::arm_compute::ActivationLayerInfoEx::ActivationFunction::RSQRT}; + + if (::internal::arm_compute::isGpuMode()) + { + auto fn = nnfw::make_unique<::arm_compute::CLActivationLayerEx>(); + + fn->configure(CAST_CL(input_alloc), CAST_CL(output_alloc), act_info); + + builder.append("RSQRT", std::move(fn)); + } + else + throw std::runtime_error("Not supported, yet"); + }; + + _builder.addStage(stage); } void Planner::visit(const ::internal::tflite::op::Equal::Node &node) -- 2.7.4