From 59b0c871620694ff45f30861cacf95d9ed4afc37 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Prasanna=20R/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 19 Oct 2018 05:56:05 +0530 Subject: [PATCH] Add CL Kernel calls for Equal op from runtime. (#3107) This patch adds CL Kernel calls for Equal op from runtime. Signed-off-by: prasannar --- runtimes/pure_arm_compute/src/compilation.cc | 54 ++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index 229683f..97fd4e4 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -57,6 +57,7 @@ #include #include #include +#include #include #include @@ -3456,9 +3457,58 @@ void Planner::visit(const ::internal::tflite::op::RSQRT::Node &node) void Planner::visit(const ::internal::tflite::op::Equal::Node &node) { - VERBOSE(Equal) << "Configure Equal operation" << std::endl; + const ::internal::tflite::operand::Index output_index{node.param().output_index}; + const ::internal::tflite::operand::Index input1_index{node.param().input1_index}; + const ::internal::tflite::operand::Index input2_index{node.param().input2_index}; - throw std::runtime_error("Not supported, yet"); + // Set Shape Constraints and TensorInfo + _builder.addShapeConstr(output_index, + asTensorInfo(asTensorShape(_ctx.at(output_index).shape(), false), + _ctx.at(output_index).type(), _ctx.at(output_index).scale(), + _ctx.at(output_index).zeroPoint())); + _builder.addShapeConstr(input1_index, + asTensorInfo(asTensorShape(_ctx.at(input1_index).shape(), false), + _ctx.at(input1_index).type(), _ctx.at(input1_index).scale(), + _ctx.at(input1_index).zeroPoint())); + _builder.addShapeConstr(input2_index, + asTensorInfo(asTensorShape(_ctx.at(input2_index).shape(), false), + _ctx.at(input2_index).type(), _ctx.at(input2_index).scale(), + _ctx.at(input2_index).zeroPoint())); + + // Construct operation parameters + struct Param + { + int output_index; + int input1_index; + int input2_index; + }; + + Param param; + + param.output_index = output_index.asInt(); + param.input1_index = input1_index.asInt(); + param.input2_index = input2_index.asInt(); + auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) { + auto output_alloc = ctx.at(::internal::tflite::operand::Index{param.output_index}); + auto input1_alloc = ctx.at(::internal::tflite::operand::Index{param.input1_index}); + auto input2_alloc = ctx.at(::internal::tflite::operand::Index{param.input2_index}); + + if (::internal::arm_compute::isGpuMode()) + { + auto fn = nnfw::make_unique<::arm_compute::CLEqual>(); + + fn->configure(CAST_CL(input1_alloc), CAST_CL(input2_alloc), CAST_CL(output_alloc)); + + builder.append("Equal", std::move(fn)); + } + else + { + // TODO Add NEON support + + throw std::runtime_error("Not supported, yet"); + } + }; + _builder.addStage(stage); } void Planner::visit(const ::internal::tflite::op::TransposeConv::Node &node) -- 2.7.4