Release 18.03
[platform/upstream/armnn.git] / src / armnn / backends / ArmComputeTensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "ArmComputeTensorUtils.hpp"
6 #include "ArmComputeUtils.hpp"
7
8 #include <armnn/Descriptors.hpp>
9
10 namespace armnn
11 {
12 namespace armcomputetensorutils
13 {
14
15 arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
16 {
17     switch(dataType)
18     {
19         case armnn::DataType::Float32:
20         {
21             return arm_compute::DataType::F32;
22         }
23         case armnn::DataType::QuantisedAsymm8:
24         {
25             return arm_compute::DataType::QASYMM8;
26         }
27         case armnn::DataType::Signed32:
28         {
29             return arm_compute::DataType::S32;
30         }
31         default:
32         {
33             BOOST_ASSERT_MSG(false, "Unknown data type");
34             return arm_compute::DataType::UNKNOWN;
35         }
36     }
37 }
38
39 arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
40 {
41     arm_compute::TensorShape shape;
42
43     // armnn tensors are (batch, channels, height, width)
44     // arm_compute tensors are (width, height, channels, batch)
45     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
46     {
47         // note that our dimensions are stored in the opposite order to ACL's
48         shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
49
50         // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
51         // arm_compute tensors expect this
52     }
53
54     // prevent arm_compute issue where tensor is flattened to nothing
55     if (shape.num_dimensions() == 0)
56     {
57         shape.set_num_dimensions(1);
58     }
59
60     return shape;
61 }
62
63 // Utility function used to build a TensorInfo object, that can be used to initialise
64 // ARM Compute Tensor and CLTensor allocators.
65 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
66 {
67     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
68     const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
69     const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
70                                                             tensorInfo.GetQuantizationOffset());
71
72     return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
73 }
74
75 arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
76 {
77     using arm_compute::PoolingType;
78     using arm_compute::DimensionRoundingType;
79     using arm_compute::PadStrideInfo;
80     using arm_compute::PoolingLayerInfo;
81     using arm_compute::Size2D;
82
83     // Resolve ARM Compute layer parameters
84     const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
85     const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
86                                                                                     descriptor.m_OutputShapeRounding);
87
88     const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
89                                       descriptor.m_StrideY,
90                                       descriptor.m_PadLeft,
91                                       descriptor.m_PadRight,
92                                       descriptor.m_PadTop,
93                                       descriptor.m_PadBottom,
94                                       rounding);
95
96     const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
97
98     const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
99
100     return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding);
101 }
102
103 arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
104 {
105     const arm_compute::NormType normType =
106         ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
107     return arm_compute::NormalizationLayerInfo(normType,
108                                                descriptor.m_NormSize,
109                                                descriptor.m_Alpha,
110                                                descriptor.m_Beta,
111                                                descriptor.m_K,
112                                                false);
113 }
114
115 arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
116 {
117     arm_compute::PermutationVector aclPerm;
118
119     unsigned int start = 0;
120     while ((start < perm.GetSize()) && (start == perm[start]))
121     {
122         ++start;
123     }
124
125     for (unsigned int i = start; i < perm.GetSize(); ++i)
126     {
127         aclPerm.set(i - start, perm[i] - start);
128     }
129
130     return aclPerm;
131 }
132
133 } // namespace armcomputetensorutils
134 } // namespace armnn