Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ArmComputeTensorUtils.cpp
index f88ed2b..8e4abaf 100644 (file)
@@ -16,23 +16,17 @@ arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
 {
     switch(dataType)
     {
+        case armnn::DataType::Float16:
+            return arm_compute::DataType::F16;
         case armnn::DataType::Float32:
-        {
             return arm_compute::DataType::F32;
-        }
         case armnn::DataType::QuantisedAsymm8:
-        {
             return arm_compute::DataType::QASYMM8;
-        }
         case armnn::DataType::Signed32:
-        {
             return arm_compute::DataType::S32;
-        }
         default:
-        {
             BOOST_ASSERT_MSG(false, "Unknown data type");
             return arm_compute::DataType::UNKNOWN;
-        }
     }
 }
 
@@ -40,15 +34,15 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
 {
     arm_compute::TensorShape shape;
 
-    // armnn tensors are (batch, channels, height, width)
-    // arm_compute tensors are (width, height, channels, batch)
+    // armnn tensors are (batch, channels, height, width).
+    // arm_compute tensors are (width, height, channels, batch).
     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
     {
-        // note that our dimensions are stored in the opposite order to ACL's
+        // Note that our dimensions are stored in the opposite order to ACL's.
         shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
 
         // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
-        // arm_compute tensors expect this
+        // arm_compute tensors expect this.
     }
 
     // prevent arm_compute issue where tensor is flattened to nothing
@@ -80,11 +74,18 @@ arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDes
     using arm_compute::PoolingLayerInfo;
     using arm_compute::Size2D;
 
-    // Resolve ARM Compute layer parameters
+    // Resolve ARM Compute layer parameters.
     const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
+
+    bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
+    //use specific constructor if global pooling
+    if(isGlobalPooling)
+    {
+        return arm_compute::PoolingLayerInfo(poolingType);
+    }
+
     const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
                                                                                     descriptor.m_OutputShapeRounding);
-
     const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
                                       descriptor.m_StrideY,
                                       descriptor.m_PadLeft,