From dd762a37d56f47306f2cee254bddaeb09f12a8ed Mon Sep 17 00:00:00 2001 From: =?utf8?q?Shubham=20Gupta/System=20SW=20/SRI-Bangalore/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 6 Nov 2018 06:35:23 +0530 Subject: [PATCH] Add CL kernel call for Transpose op from runtime (#3470) This patch adds cl kernel call for transpose op from PACL. Signed-off-by: shubham --- runtimes/pure_arm_compute/src/compilation.cc | 35 ++++++++++++++++++---------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/runtimes/pure_arm_compute/src/compilation.cc b/runtimes/pure_arm_compute/src/compilation.cc index eafd239..493208c 100644 --- a/runtimes/pure_arm_compute/src/compilation.cc +++ b/runtimes/pure_arm_compute/src/compilation.cc @@ -52,7 +52,7 @@ #include #include #include -#include +#include #include #include #include @@ -3406,13 +3406,13 @@ void Planner::visit(const ::internal::tflite::op::LSTM::Node &node) void Planner::visit(const ::internal::tflite::op::Transpose::Node &node) { VERBOSE(Transpose) << "Configure Transpose operation" << std::endl; - // Transpose supports only height-wight dimention support. - // CLPermute can be used to implement generic transpose along any axis - // But CLPermute only implements [2,0,1], [1,2,0], [3,2,0,1] - // TODO Implement other permutation CLPermute function and provide generic transpose const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index}; const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index}; + const ::internal::tflite::operand::Index permu_index{node.param().permu_index}; + + assert(_ctx.at(ifm_index).shape().rank() == _ctx.at(ofm_index).shape().rank()); + assert(_ctx.at(permu_index).hasData() == true); // Set shape constraints _builder.addShapeConstr( @@ -3421,30 +3421,41 @@ void Planner::visit(const ::internal::tflite::op::Transpose::Node &node) _builder.addShapeConstr( ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(), _ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint())); - // NNAPI spec provides permutation vector for generic transpose - // TODO Make the permutation vector a part of Param + struct Param { int ofm_index; int ifm_index; + const int32_t *pv; + int rank; }; Param param; param.ofm_index = ofm_index.asInt(); param.ifm_index = ifm_index.asInt(); + param.pv = reinterpret_cast(_ctx.at(permu_index).data().base()); + param.rank = _ctx.at(ifm_index).shape().rank(); auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) { + auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index}); const auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index}); - // CLTranspose assumes only spatial transpose, will be replaced with CLPermute - // TODO Check the validity of permutation vector, then call CLPermute with permu vector - auto fn = nnfw::make_unique<::arm_compute::CLTranspose>(); + if (::internal::arm_compute::isGpuMode()) + { + auto fn = nnfw::make_unique<::arm_compute::CLPermuteEx>(); + + fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc), + getARMComputePermutationVector(param.rank, param.pv)); - fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc)); + builder.append("Transpose", std::move(fn)); + } + else + { + throw std::runtime_error("Not supported, yet"); + } - builder.append("Transpose", std::move(fn)); }; _builder.addStage(stage); -- 2.7.4