[Pure CL] Use squeezed shape in setOutput (#1630)
author박종현/동작제어Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 11 Jun 2018 05:31:45 +0000 (14:31 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 11 Jun 2018 05:31:45 +0000 (14:31 +0900)
* [Pure CL] Use squeezed shape in setOutput

This commit simplifies ANeuralNetworksExecution_setOutput via using
squeezed shape.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Update comments

runtimes/pure_arm_compute/src/execution.cc

index 50f6842..1bf7a8b 100644 (file)
@@ -284,6 +284,23 @@ int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution *execution, int32
   return ANEURALNETWORKS_NO_ERROR;
 }
 
+// squeeze(shape) eliminates all the dimensions whose dimensionality is 1
+// For example, squeeze([3, 1, 3]) returns [3, 3]
+static nnfw::util::tensor::Shape squeeze(const nnfw::util::tensor::Shape &shape)
+{
+  nnfw::util::tensor::Shape res(0);
+
+  for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+  {
+    if (shape.dim(axis) != 1)
+    {
+      res.append(shape.dim(axis));
+    }
+  }
+
+  return res;
+}
+
 int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution *execution, int32_t index,
                                        const ANeuralNetworksOperandType *type, void *buffer,
                                        size_t length)
@@ -294,20 +311,15 @@ int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution *execution, int3
 
   const auto operand_index = execution->plan().model().outputs.at(index);
   int32_t output_type = operands.at(operand_index).type();
-  if (operands.at(operand_index).shape().rank() == 1)
-  {
-    const auto len = operands.at(operand_index).shape().dim(0);
+  const auto squeezed_shape = squeeze(operands.at(operand_index).shape());
 
-    asVectorSink(execution, output_type, index, len, buffer, length);
-  }
-  else if ((operands.at(operand_index).shape().rank() == 2) &&
-           (operands.at(operand_index).shape().dim(0) == 1))
+  if (squeezed_shape.rank() == 1)
   {
-    const auto len = operands.at(operand_index).shape().dim(1);
+    const auto len = squeezed_shape.dim(0);
 
     asVectorSink(execution, output_type, index, len, buffer, length);
   }
-  else if (operands.at(operand_index).shape().rank() == 4)
+  else if (squeezed_shape.rank() == 3)
   {
     const auto &operand_shape = operands.at(operand_index).shape().asFeature();