Bug fix for output shape in Permute (#4201)
authorShubham Gupta/SNAP /SRI-Bangalore/Engineer/삼성전자 <shub98.gupta@samsung.com>
Mon, 14 Jan 2019 08:16:44 +0000 (13:46 +0530)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 14 Jan 2019 08:16:44 +0000 (17:16 +0900)
This patch will fix the bug in calc output shape in Permute

Signed-off-by: shubham <shub98.gupta@samsung.com>
libs/ARMComputeEx/src/core/CL/kernels/CLPermuteExKernel.cpp
runtimes/pure_arm_compute/src/internal/arm_compute/Cast.cc

index 64542cb..8678cce 100644 (file)
@@ -39,10 +39,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output,
       input, 1, DataType::U8, DataType::S8, DataType::QASYMM8, DataType::U16, DataType::S16,
       DataType::U32, DataType::S32, DataType::F16, DataType::F32);
 
-  // output_shape calculation bug
-  // TODO bug fix
-  // const TensorShape output_shape =
-  //    misc::shape_calculator::compute_permutation_output_shape(*input, perm);
+  const TensorShape output_shape =
+      misc::shape_calculator::compute_permutation_output_shape(*input, perm);
 
   ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() < 1 || input->num_dimensions() > 4,
                                   "Permutation upto 4-D input tensor is supported");
@@ -57,7 +55,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output,
   // Validate configured output
   if (output->total_size() != 0)
   {
-    // ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
   }
   return Status{};
@@ -87,10 +85,15 @@ void CLPermuteExKernel::configure(const ICLTensor *input, ICLTensor *output,
   build_opts.emplace("-DDEPTH_IN=" + support::cpp11::to_string(input->info()->dimension(2)));
 
   // New positions of batch(D), height(H), width(w) and channel(C) based on permutation vector
-  build_opts.emplace("-DP1=" + support::cpp11::to_string(perm[0]));
-  build_opts.emplace("-DP2=" + support::cpp11::to_string(perm[1]));
-  build_opts.emplace("-DP3=" + support::cpp11::to_string(perm[2]));
-  build_opts.emplace("-DP4=" + support::cpp11::to_string(perm[3]));
+  build_opts.emplace("-DP1=" +
+                     support::cpp11::to_string((perm.num_dimensions() >= 1) ? perm[0] : 0));
+  build_opts.emplace("-DP2=" +
+                     support::cpp11::to_string((perm.num_dimensions() >= 2) ? perm[1] : 1));
+  build_opts.emplace("-DP3=" +
+                     support::cpp11::to_string((perm.num_dimensions() >= 3) ? perm[2] : 2));
+  build_opts.emplace("-DP4=" +
+                     support::cpp11::to_string((perm.num_dimensions() >= 4) ? perm[3] : 3));
+
   _kernel = static_cast<cl::Kernel>(
       CLKernelLibraryEx::get().create_kernel("permute_generic", build_opts));
 
index ff2f793..46e3925 100644 (file)
@@ -52,7 +52,7 @@
   assert(rank <= 4);
   assert(runtime_pv != nullptr);
 
-  int new_pv[4] = {0};
+  int new_pv[4] = {0, 1, 2, 3};
   ::arm_compute::Coordinates axises = getARMComputeAxises(rank);
 
   if (rank == 4)
     }
   }
 
-  return ::arm_compute::PermutationVector{new_pv[0], new_pv[1], new_pv[2], new_pv[3]};
+  ::arm_compute::PermutationVector ACL_PV =
+      ::arm_compute::PermutationVector{new_pv[0], new_pv[1], new_pv[2], new_pv[3]};
+  ACL_PV.set_num_dimensions(rank);
+
+  return ACL_PV;
 }
 
 ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape,