Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ArmComputeTensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "ArmComputeTensorUtils.hpp"
6 #include "ArmComputeUtils.hpp"
7
8 #include <armnn/Descriptors.hpp>
9
10 namespace armnn
11 {
12 namespace armcomputetensorutils
13 {
14
15 arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
16 {
17     switch(dataType)
18     {
19         case armnn::DataType::Float16:
20             return arm_compute::DataType::F16;
21         case armnn::DataType::Float32:
22             return arm_compute::DataType::F32;
23         case armnn::DataType::QuantisedAsymm8:
24             return arm_compute::DataType::QASYMM8;
25         case armnn::DataType::Signed32:
26             return arm_compute::DataType::S32;
27         default:
28             BOOST_ASSERT_MSG(false, "Unknown data type");
29             return arm_compute::DataType::UNKNOWN;
30     }
31 }
32
33 arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
34 {
35     arm_compute::TensorShape shape;
36
37     // armnn tensors are (batch, channels, height, width).
38     // arm_compute tensors are (width, height, channels, batch).
39     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
40     {
41         // Note that our dimensions are stored in the opposite order to ACL's.
42         shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
43
44         // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
45         // arm_compute tensors expect this.
46     }
47
48     // prevent arm_compute issue where tensor is flattened to nothing
49     if (shape.num_dimensions() == 0)
50     {
51         shape.set_num_dimensions(1);
52     }
53
54     return shape;
55 }
56
57 // Utility function used to build a TensorInfo object, that can be used to initialise
58 // ARM Compute Tensor and CLTensor allocators.
59 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
60 {
61     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
62     const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
63     const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
64                                                             tensorInfo.GetQuantizationOffset());
65
66     return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
67 }
68
69 arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
70 {
71     using arm_compute::PoolingType;
72     using arm_compute::DimensionRoundingType;
73     using arm_compute::PadStrideInfo;
74     using arm_compute::PoolingLayerInfo;
75     using arm_compute::Size2D;
76
77     // Resolve ARM Compute layer parameters.
78     const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
79
80     bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
81     //use specific constructor if global pooling
82     if(isGlobalPooling)
83     {
84         return arm_compute::PoolingLayerInfo(poolingType);
85     }
86
87     const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
88                                                                                     descriptor.m_OutputShapeRounding);
89     const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
90                                       descriptor.m_StrideY,
91                                       descriptor.m_PadLeft,
92                                       descriptor.m_PadRight,
93                                       descriptor.m_PadTop,
94                                       descriptor.m_PadBottom,
95                                       rounding);
96
97     const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
98
99     const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
100
101     return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding);
102 }
103
104 arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
105 {
106     const arm_compute::NormType normType =
107         ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
108     return arm_compute::NormalizationLayerInfo(normType,
109                                                descriptor.m_NormSize,
110                                                descriptor.m_Alpha,
111                                                descriptor.m_Beta,
112                                                descriptor.m_K,
113                                                false);
114 }
115
116 arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
117 {
118     arm_compute::PermutationVector aclPerm;
119
120     unsigned int start = 0;
121     while ((start < perm.GetSize()) && (start == perm[start]))
122     {
123         ++start;
124     }
125
126     for (unsigned int i = start; i < perm.GetSize(); ++i)
127     {
128         aclPerm.set(i - start, perm[i] - start);
129     }
130
131     return aclPerm;
132 }
133
134 } // namespace armcomputetensorutils
135 } // namespace armnn