Add CL kernel call for Transpose op from runtime (#3470)
authorShubham Gupta/System SW /SRI-Bangalore/Engineer/삼성전자 <shub98.gupta@samsung.com>
Tue, 6 Nov 2018 01:05:23 +0000 (06:35 +0530)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Tue, 6 Nov 2018 01:05:23 +0000 (10:05 +0900)
This patch adds cl kernel call for transpose op from PACL.

Signed-off-by: shubham <shub98.gupta@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index eafd239..493208c 100644 (file)
@@ -52,7 +52,7 @@
 #include <arm_compute/runtime/CL/functions/CLDequantizationLayer.h>
 #include <arm_compute/runtime/CL/functions/CLDepthToSpace.h>
 #include <arm_compute/runtime/CL/functions/CLReductionMean.h>
-#include <arm_compute/runtime/CL/functions/CLTranspose.h>
+#include <arm_compute/runtime/CL/functions/CLPermuteEx.h>
 #include <arm_compute/runtime/CL/functions/CLRNNLayer.h>
 #include <arm_compute/runtime/CL/functions/CLFloor.h>
 #include <arm_compute/runtime/CL/functions/CLCopy.h>
@@ -3406,13 +3406,13 @@ void Planner::visit(const ::internal::tflite::op::LSTM::Node &node)
 void Planner::visit(const ::internal::tflite::op::Transpose::Node &node)
 {
   VERBOSE(Transpose) << "Configure Transpose operation" << std::endl;
-  // Transpose supports only height-wight dimention support.
-  // CLPermute can be used to implement generic transpose along any axis
-  // But CLPermute only implements [2,0,1], [1,2,0], [3,2,0,1]
 
-  // TODO Implement other permutation CLPermute function and provide generic transpose
   const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
   const ::internal::tflite::operand::Index ifm_index{node.param().ifm_index};
+  const ::internal::tflite::operand::Index permu_index{node.param().permu_index};
+
+  assert(_ctx.at(ifm_index).shape().rank() == _ctx.at(ofm_index).shape().rank());
+  assert(_ctx.at(permu_index).hasData() == true);
 
   // Set shape constraints
   _builder.addShapeConstr(
@@ -3421,30 +3421,41 @@ void Planner::visit(const ::internal::tflite::op::Transpose::Node &node)
   _builder.addShapeConstr(
       ifm_index, asTensorInfo(asTensorShape(_ctx.at(ifm_index).shape()), _ctx.at(ifm_index).type(),
                               _ctx.at(ifm_index).scale(), _ctx.at(ifm_index).zeroPoint()));
-  // NNAPI spec provides permutation vector for generic transpose
-  // TODO Make the permutation vector a part of Param
+
   struct Param
   {
     int ofm_index;
     int ifm_index;
+    const int32_t *pv;
+    int rank;
   };
 
   Param param;
 
   param.ofm_index = ofm_index.asInt();
   param.ifm_index = ifm_index.asInt();
+  param.pv = reinterpret_cast<const int32_t *>(_ctx.at(permu_index).data().base());
+  param.rank = _ctx.at(ifm_index).shape().rank();
 
   auto stage = [param](const IAllocationContext &ctx, IExecutionBuilder &builder) {
+
     auto ofm_alloc = ctx.at(::internal::tflite::operand::Index{param.ofm_index});
     const auto ifm_alloc = ctx.at(::internal::tflite::operand::Index{param.ifm_index});
 
-    // CLTranspose assumes only spatial transpose, will be replaced with CLPermute
-    // TODO Check the validity of permutation vector, then call CLPermute with permu vector
-    auto fn = nnfw::make_unique<::arm_compute::CLTranspose>();
+    if (::internal::arm_compute::isGpuMode())
+    {
+      auto fn = nnfw::make_unique<::arm_compute::CLPermuteEx>();
+
+      fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc),
+                    getARMComputePermutationVector(param.rank, param.pv));
 
-    fn->configure(CAST_CL(ifm_alloc), CAST_CL(ofm_alloc));
+      builder.append("Transpose", std::move(fn));
+    }
+    else
+    {
+      throw std::runtime_error("Not supported, yet");
+    }
 
-    builder.append("Transpose", std::move(fn));
   };
 
   _builder.addStage(stage);