{
switch(dataType)
{
+ case armnn::DataType::Float16:
+ return arm_compute::DataType::F16;
case armnn::DataType::Float32:
- {
return arm_compute::DataType::F32;
- }
case armnn::DataType::QuantisedAsymm8:
- {
return arm_compute::DataType::QASYMM8;
- }
case armnn::DataType::Signed32:
- {
return arm_compute::DataType::S32;
- }
default:
- {
BOOST_ASSERT_MSG(false, "Unknown data type");
return arm_compute::DataType::UNKNOWN;
- }
}
}
{
arm_compute::TensorShape shape;
- // armnn tensors are (batch, channels, height, width)
- // arm_compute tensors are (width, height, channels, batch)
+ // armnn tensors are (batch, channels, height, width).
+ // arm_compute tensors are (width, height, channels, batch).
for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
{
- // note that our dimensions are stored in the opposite order to ACL's
+ // Note that our dimensions are stored in the opposite order to ACL's.
shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
// TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
- // arm_compute tensors expect this
+ // arm_compute tensors expect this.
}
// prevent arm_compute issue where tensor is flattened to nothing
using arm_compute::PoolingLayerInfo;
using arm_compute::Size2D;
- // Resolve ARM Compute layer parameters
+ // Resolve ARM Compute layer parameters.
const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
+
+ bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
+ //use specific constructor if global pooling
+ if(isGlobalPooling)
+ {
+ return arm_compute::PoolingLayerInfo(poolingType);
+ }
+
const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
descriptor.m_OutputShapeRounding);
-
const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
descriptor.m_StrideY,
descriptor.m_PadLeft,