MLCE-347 Bug fixes in Reduce: QueueDescriptor.validate and init REDUCE_MIN
authorTeresa Charlin <teresa.charlinreyes@arm.com>
Thu, 11 Feb 2021 23:05:40 +0000 (23:05 +0000)
committerTeresa Charlin <teresa.charlinreyes@arm.com>
Thu, 11 Feb 2021 23:07:30 +0000 (23:07 +0000)
* Allow input tensors of any rank in ReduceQueueDescriptor::validate
* Fix VTS tests failing for REDUCE_MIN due to initialization

Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Id8fba1662ade4e0a967093fe5a53b275847f2393

src/backends/backendsCommon/WorkloadData.cpp
src/backends/reference/workloads/Reduce.cpp

index b51099f..90db57f 100644 (file)
@@ -3643,8 +3643,6 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
 
-    ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
-
     std::vector<DataType> supportedTypes =
     {
         DataType::BFloat16,
index 31c6262..392ef8e 100644 (file)
@@ -81,17 +81,21 @@ void Reduce(const TensorInfo& inputInfo,
 
     // Initialise temp output
     std::vector<float> tempOut(numOutputs);
-    if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
+    switch(reduceOperation)
     {
-        for (unsigned int idx = 0; idx < numOutputs; ++idx)
-        {
-            input[idx];
-            tempOut[idx] = input.Get();
-        }
-    }
-    else
-    {
-        std::fill(tempOut.begin(), tempOut.end(), 0.0);
+        case ReduceOperation::Mean:
+        case ReduceOperation::Sum:
+            std::fill(tempOut.begin(), tempOut.end(), 0.0);
+            break;
+        case ReduceOperation::Max:
+            std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
+            break;
+        case ReduceOperation::Min:
+            std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
+            break;
+        default:
+            throw armnn::InvalidArgumentException("Unknown reduce method: " +
+                std::to_string(static_cast<int>(reduceOperation)));
     }
 
     // Initialise temp index