Make TopK support 1D, 2D tensor (#2323)
author윤현식/동작제어Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 20 Aug 2018 02:55:30 +0000 (11:55 +0900)
committer박세희/동작제어Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Mon, 20 Aug 2018 02:55:30 +0000 (11:55 +0900)
Previous TopK supports 1D only.
This commit makes TopK support 2D also in runtime layer.
(ACL layer PR : https://github.sec.samsung.net/RS7-RuntimeNTools/ComputeLibrary/pull/41)

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
runtimes/pure_arm_compute/src/compilation.cc

index d519614..d75cd99 100644 (file)
@@ -2357,11 +2357,12 @@ void Planner::visit(const ::internal::tflite::op::TopKV2::Node &node)
   const ::internal::tflite::operand::Index k_index{node.param().k_index};
 
   // Currently, we only support the vector input.
-  assert(_ctx.at(inputData_index).shape().rank() == 1);
+  assert(_ctx.at(inputData_index).shape().rank() == 1 ||
+         _ctx.at(inputData_index).shape().rank() == 2);
 
-  const auto outputValues_shape = _ctx.at(outputValues_index).shape().asVector();
-  const auto outputIndices_shape = _ctx.at(outputIndices_index).shape().asVector();
-  const auto inputData_shape = _ctx.at(inputData_index).shape().asVector();
+  const auto outputValues_shape = _ctx.at(outputValues_index).shape().asTensor();
+  const auto outputIndices_shape = _ctx.at(outputIndices_index).shape().asTensor();
+  const auto inputData_shape = _ctx.at(inputData_index).shape().asTensor();
   const int32_t k = _ctx.at(k_index).asScalar<int32_t>();
 
   // Set shape constraints