2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
5 #include "ArmComputeTensorUtils.hpp"
6 #include "ArmComputeUtils.hpp"
8 #include <armnn/Descriptors.hpp>
12 namespace armcomputetensorutils
15 arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
19 case armnn::DataType::Float32:
21 return arm_compute::DataType::F32;
23 case armnn::DataType::QuantisedAsymm8:
25 return arm_compute::DataType::QASYMM8;
27 case armnn::DataType::Signed32:
29 return arm_compute::DataType::S32;
33 BOOST_ASSERT_MSG(false, "Unknown data type");
34 return arm_compute::DataType::UNKNOWN;
39 arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
41 arm_compute::TensorShape shape;
43 // armnn tensors are (batch, channels, height, width)
44 // arm_compute tensors are (width, height, channels, batch)
45 for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
47 // note that our dimensions are stored in the opposite order to ACL's
48 shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
50 // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
51 // arm_compute tensors expect this
54 // prevent arm_compute issue where tensor is flattened to nothing
55 if (shape.num_dimensions() == 0)
57 shape.set_num_dimensions(1);
63 // Utility function used to build a TensorInfo object, that can be used to initialise
64 // ARM Compute Tensor and CLTensor allocators.
65 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
67 const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
68 const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
69 const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
70 tensorInfo.GetQuantizationOffset());
72 return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
// Builds an arm_compute::PoolingLayerInfo from an armnn Pooling2dDescriptor.
//
// The pooling type and output-shape rounding are translated by the Convert* helpers, the
// stride/padding fields are packed into a PadStrideInfo, and padding is excluded from the
// pooling computation when the descriptor's padding method is PaddingMethod::Exclude.
//
// NOTE(review): several PadStrideInfo constructor arguments are not visible in this view of
// the file, and `rounding` is not referenced on any visible line - presumably it is the final
// PadStrideInfo argument; confirm against the full source.
75 arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
77 using arm_compute::PoolingType;
78 using arm_compute::DimensionRoundingType;
79 using arm_compute::PadStrideInfo;
80 using arm_compute::PoolingLayerInfo;
82 // Resolve the ARM Compute layer parameters from the armnn descriptor.
83 const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
84 const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
85 descriptor.m_OutputShapeRounding);
// Stride and padding configuration handed to the pooling layer.
87 const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
90 descriptor.m_PadRight,
92 descriptor.m_PadBottom,
// True when padded elements must not contribute to the pooling result.
95 const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
// NOTE(review): only m_PoolWidth is forwarded as the pool size; a pool height is not used in
// the visible code - confirm whether non-square pooling windows are meant to be unsupported.
97 return arm_compute::PoolingLayerInfo(poolingType, descriptor.m_PoolWidth, padStrideInfo, excludePadding);
// Builds an arm_compute::NormalizationLayerInfo from an armnn NormalizationDescriptor.
//
// The normalization type is derived from the descriptor's channel type via
// ConvertNormalizationAlgorithmChannelToAclNormType(), and the normalization size is
// forwarded unchanged.
//
// NOTE(review): the remaining NormalizationLayerInfo constructor arguments are not visible in
// this view of the file; confirm them against the full source.
100 arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
102 const arm_compute::NormType normType =
103 ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
104 return arm_compute::NormalizationLayerInfo(normType,
105 descriptor.m_NormSize,
112 arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
114 arm_compute::PermutationVector aclPerm;
116 unsigned int start = 0;
117 while ((start == perm[start]) && (start < perm.GetSize()))
122 for (unsigned int i = start; i < perm.GetSize(); ++i)
124 aclPerm.set(i - start, perm[i] - start);
130 } // namespace armcomputetensorutils