// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
5 #include "ArmComputeTensorUtils.hpp"
6 #include "ArmComputeUtils.hpp"
8 #include <armnn/Descriptors.hpp>
12 namespace armcomputetensorutils
15 arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
19 case armnn::DataType::Float16:
20 return arm_compute::DataType::F16;
21 case armnn::DataType::Float32:
22 return arm_compute::DataType::F32;
23 case armnn::DataType::QuantisedAsymm8:
24 return arm_compute::DataType::QASYMM8;
25 case armnn::DataType::Signed32:
26 return arm_compute::DataType::S32;
28 BOOST_ASSERT_MSG(false, "Unknown data type");
29 return arm_compute::DataType::UNKNOWN;
33 arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
35 arm_compute::TensorShape shape;
37 // armnn tensors are (batch, channels, height, width).
38 // arm_compute tensors are (width, height, channels, batch).
39 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
41 // Note that our dimensions are stored in the opposite order to ACL's.
42 shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
44 // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
45 // arm_compute tensors expect this.
48 // prevent arm_compute issue where tensor is flattened to nothing
49 if (shape.num_dimensions() == 0)
51 shape.set_num_dimensions(1);
57 // Utility function used to build a TensorInfo object, that can be used to initialise
58 // ARM Compute Tensor and CLTensor allocators.
59 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
61 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
62 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
63 const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
64 tensorInfo.GetQuantizationOffset());
66 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
69 arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
71 using arm_compute::PoolingType;
72 using arm_compute::DimensionRoundingType;
73 using arm_compute::PadStrideInfo;
74 using arm_compute::PoolingLayerInfo;
75 using arm_compute::Size2D;
77 // Resolve ARM Compute layer parameters.
78 const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
80 bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
81 //use specific constructor if global pooling
84 return arm_compute::PoolingLayerInfo(poolingType);
87 const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
88 descriptor.m_OutputShapeRounding);
89 const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
92 descriptor.m_PadRight,
94 descriptor.m_PadBottom,
97 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
99 const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
101 return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding);
104 arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
106 const arm_compute::NormType normType =
107 ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
108 return arm_compute::NormalizationLayerInfo(normType,
109 descriptor.m_NormSize,
116 arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
118 arm_compute::PermutationVector aclPerm;
120 unsigned int start = 0;
121 while ((start < perm.GetSize()) && (start == perm[start]))
126 for (unsigned int i = start; i < perm.GetSize(); ++i)
128 aclPerm.set(i - start, perm[i] - start);
134 } // namespace armcomputetensorutils