IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / aclCommon / ArmComputeTensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <aclCommon/ArmComputeTensorUtils.hpp>
6 #include <aclCommon/ArmComputeUtils.hpp>
7
8 #include "armnn/Exceptions.hpp"
9 #include <armnn/Descriptors.hpp>
10
11 namespace armnn
12 {
13 namespace armcomputetensorutils
14 {
15
16 arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
17 {
18     switch(dataType)
19     {
20         case armnn::DataType::Float16:
21             return arm_compute::DataType::F16;
22         case armnn::DataType::Float32:
23             return arm_compute::DataType::F32;
24         case armnn::DataType::QuantisedAsymm8:
25             return arm_compute::DataType::QASYMM8;
26         case armnn::DataType::Signed32:
27             return arm_compute::DataType::S32;
28         default:
29             BOOST_ASSERT_MSG(false, "Unknown data type");
30             return arm_compute::DataType::UNKNOWN;
31     }
32 }
33
34 arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
35 {
36     arm_compute::TensorShape shape;
37
38     // armnn tensors are (batch, channels, height, width).
39     // arm_compute tensors are (width, height, channels, batch).
40     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
41     {
42         // Note that our dimensions are stored in the opposite order to ACL's.
43         shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
44
45         // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
46         // arm_compute tensors expect this.
47     }
48
49     // prevent arm_compute issue where tensor is flattened to nothing
50     if (shape.num_dimensions() == 0)
51     {
52         shape.set_num_dimensions(1);
53     }
54
55     return shape;
56 }
57
58 // Utility function used to build a TensorInfo object, that can be used to initialise
59 // ARM Compute Tensor and CLTensor allocators.
60 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
61 {
62     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
63     const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
64     const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
65                                                             tensorInfo.GetQuantizationOffset());
66
67     return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
68 }
69
70 arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
71 {
72     switch(dataLayout)
73     {
74         case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
75
76         case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
77
78         default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
79                                                 std::to_string(static_cast<int>(dataLayout)) + "]");
80     }
81 }
82
83 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
84                                                   armnn::DataLayout dataLayout)
85 {
86     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
87     const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
88     const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
89                                                             tensorInfo.GetQuantizationOffset());
90
91     arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
92     clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
93
94     return clTensorInfo;
95 }
96
97 arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
98 {
99     using arm_compute::PoolingType;
100     using arm_compute::DimensionRoundingType;
101     using arm_compute::PadStrideInfo;
102     using arm_compute::PoolingLayerInfo;
103     using arm_compute::Size2D;
104
105     // Resolve ARM Compute layer parameters.
106     const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
107
108     bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
109     //use specific constructor if global pooling
110     if(isGlobalPooling)
111     {
112         return arm_compute::PoolingLayerInfo(poolingType);
113     }
114
115     const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
116                                                                                     descriptor.m_OutputShapeRounding);
117     const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
118                                       descriptor.m_StrideY,
119                                       descriptor.m_PadLeft,
120                                       descriptor.m_PadRight,
121                                       descriptor.m_PadTop,
122                                       descriptor.m_PadBottom,
123                                       rounding);
124
125     const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
126
127     const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
128
129     return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding);
130 }
131
132 arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
133 {
134     const arm_compute::NormType normType =
135         ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
136     return arm_compute::NormalizationLayerInfo(normType,
137                                                descriptor.m_NormSize,
138                                                descriptor.m_Alpha,
139                                                descriptor.m_Beta,
140                                                descriptor.m_K,
141                                                false);
142 }
143
144 arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
145 {
146     arm_compute::PermutationVector aclPerm;
147
148     unsigned int start = 0;
149     while ((start < perm.GetSize()) && (start == perm[start]))
150     {
151         ++start;
152     }
153
154     for (unsigned int i = start; i < perm.GetSize(); ++i)
155     {
156         aclPerm.set(i - start, perm[i] - start);
157     }
158
159     return aclPerm;
160 }
161
162 } // namespace armcomputetensorutils
163 } // namespace armnn