From: TANUJ TEKRIWAL/System SW /SRI-Bangalore/Engineer/삼성전자 Date: Tue, 16 Oct 2018 01:40:49 +0000 (+0530) Subject: nnfw: RSQRT PACL Changes for RSQRT Support (#3040) X-Git-Tag: 0.3~631 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=95857aaf55fab60878793e567e8f914224e765c7;p=platform%2Fcore%2Fml%2Fnnfw.git nnfw: RSQRT PACL Changes for RSQRT Support (#3040) This patch adds support for RSQRT in PACL Codebase(compilation.cc) Signed-off-by: Tanuj Tekriwal --- diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index e7ae8e1..05a9cba 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -3369,7 +3370,46 @@ void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node) { VERBOSE(RSQRT) << "Configure Rsqrt operation" << std::endl; - throw std::runtime_error("Not supported, yet"); + const ::internal::tflite::operand::Index output_index{node.param().output_index}; + const ::internal::tflite::operand::Index input_index{node.param().input_index}; + + // Set shape constraints + _builder.addShapeConstr(output_index, asTensorInfo(asTensorShape(_ctx.at(output_index).shape()), + _ctx.at(output_index).type())); + _builder.addShapeConstr(input_index, asTensorInfo(asTensorShape(_ctx.at(input_index).shape()), + _ctx.at(input_index).type())); + + struct Param + { + int output_index; + int input_index; + }; + + Param param; + + param.output_index = output_index.asInt(); + param.input_index = input_index.asInt(); + + auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) { + auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index}); + auto input_alloc = ctx.at(::internal::tflite::operand::Index{param.input_index}); + + const ::arm_compute::ActivationLayerInfoEx act_info{ + ::arm_compute::ActivationLayerInfoEx::ActivationFunction::RSQRT}; + + if (::internal::arm_compute::isGpuMode()) + { + auto fn = nnfw::make_unique<::arm_compute::CLActivationLayerEx>(); + + fn->configure(CAST_CL(input_alloc), CAST_CL(output_alloc), act_info); + + builder.append("RSQRT", std::move(fn)); + } + else + throw std::runtime_error("Not supported, yet"); + }; + + _builder.addStage(stage); } void Planner::visit(const ::internal::tflite::op::Equal::Node &node)