Release 18.02
[platform/upstream/armnn.git] / src / armnn / backends / ArmComputeTensorUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #include "ArmComputeTensorUtils.hpp"
6 #include "ArmComputeUtils.hpp"
7
8 #include <armnn/Descriptors.hpp>
9
10 namespace armnn
11 {
12 namespace armcomputetensorutils
13 {
14
15 arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
16 {
17     switch(dataType)
18     {
19         case armnn::DataType::Float32:
20         {
21             return arm_compute::DataType::F32;
22         }
23         case armnn::DataType::QuantisedAsymm8:
24         {
25             return arm_compute::DataType::QASYMM8;
26         }
27         case armnn::DataType::Signed32:
28         {
29             return arm_compute::DataType::S32;
30         }
31         default:
32         {
33             BOOST_ASSERT_MSG(false, "Unknown data type");
34             return arm_compute::DataType::UNKNOWN;
35         }
36     }
37 }
38
39 arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
40 {
41     arm_compute::TensorShape shape;
42
43     // armnn tensors are (batch, channels, height, width)
44     // arm_compute tensors are (width, height, channels, batch)
45     for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
46     {
47         // note that our dimensions are stored in the opposite order to ACL's
48         shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
49
50         // TensorShape::set() flattens leading ones, so that batch size 1 cannot happen.
51         // arm_compute tensors expect this
52     }
53
54     // prevent arm_compute issue where tensor is flattened to nothing
55     if (shape.num_dimensions() == 0)
56     {
57         shape.set_num_dimensions(1);
58     }
59
60     return shape;
61 }
62
63 // Utility function used to build a TensorInfo object, that can be used to initialise
64 // ARM Compute Tensor and CLTensor allocators.
65 arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
66 {
67     const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
68     const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
69     const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
70                                                             tensorInfo.GetQuantizationOffset());
71
72     return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
73 }
74
75 arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
76 {
77     using arm_compute::PoolingType;
78     using arm_compute::DimensionRoundingType;
79     using arm_compute::PadStrideInfo;
80     using arm_compute::PoolingLayerInfo;
81
82     // Resolve ARM Compute layer parameters
83     const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
84     const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
85                                                                                     descriptor.m_OutputShapeRounding);
86
87     const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
88                                       descriptor.m_StrideY,
89                                       descriptor.m_PadLeft,
90                                       descriptor.m_PadRight,
91                                       descriptor.m_PadTop,
92                                       descriptor.m_PadBottom,
93                                       rounding);
94
95     const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
96
97     return arm_compute::PoolingLayerInfo(poolingType, descriptor.m_PoolWidth, padStrideInfo, excludePadding);
98 }
99
100 arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
101 {
102     const arm_compute::NormType normType =
103         ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
104     return arm_compute::NormalizationLayerInfo(normType,
105                                                descriptor.m_NormSize,
106                                                descriptor.m_Alpha,
107                                                descriptor.m_Beta,
108                                                descriptor.m_K,
109                                                false);
110 }
111
112 arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
113 {
114     arm_compute::PermutationVector aclPerm;
115
116     unsigned int start = 0;
117     while ((start == perm[start]) && (start < perm.GetSize()))
118     {
119         ++start;
120     }
121
122     for (unsigned int i = start; i < perm.GetSize(); ++i)
123     {
124         aclPerm.set(i - start, perm[i] - start);
125     }
126
127     return aclPerm;
128 }
129
130 } // namespace armcomputetensorutils
131 } // namespace armnn