Replace uses of non-standard C++:
[platform/upstream/armnn.git] / src / backends / backendsCommon / WorkloadUtils.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "WorkloadUtils.hpp"
7
8 namespace armnn
9 {
10
11 armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor,
12                                  const PermutationVector& permutationVector, void* permuteBuffer)
13 {
14     BOOST_ASSERT_MSG(tensor, "Invalid input tensor");
15     BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
16
17     TensorInfo tensorInfo = tensor->GetTensorInfo();
18
19     if (permutationVector.GetSize() > 0)
20     {
21         tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector);
22         armnnUtils::Permute(tensorInfo.GetShape(), permutationVector,
23                             tensor->GetConstTensor<void>(), permuteBuffer,
24                             GetDataTypeSize(tensorInfo.GetDataType()));
25     }
26     else
27     {
28         ::memcpy(permuteBuffer, tensor->GetConstTensor<void>(), tensorInfo.GetNumBytes());
29     }
30
31     return ConstTensor(tensorInfo, permuteBuffer);
32 }
33
34 void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout)
35 {
36     // Reshape the weights in-place
37     const TensorShape& weightShape = weightInfo.GetShape();
38     switch (dataLayout)
39     {
40         case DataLayout::NHWC:
41             // The data layout is NHWC, reshape from [ H, W, I, M ] to [ 1, H, W, I * M ]
42             weightInfo.SetShape({ 1,
43                                   weightShape[0],
44                                   weightShape[1],
45                                   weightShape[2] * weightShape[3] });
46             weightInfo.SetShape({ 1,
47                                   weightShape[0] * weightShape[1],
48                                   weightShape[2],
49                                   weightShape[3] });
50             break;
51         case DataLayout::NCHW:
52         default:
53             // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ]
54             weightInfo.SetShape({ 1, weightShape[0] * weightShape[1], weightShape[2], weightShape[3] });
55             break;
56     }
57 }
58
59 template <typename DataType>
60 ConstTensor ReorderWeightChannelsForAcl(const ConstTensor& weightHandle, DataLayout dataLayout, void* permuteBuffer)
61 {
62     DataType* weight = static_cast<DataType*>(permuteBuffer);
63     const TensorShape& weightShape = weightHandle.GetShape();
64     unsigned int multiplier;
65     unsigned int height;
66     unsigned int width;
67     unsigned int inputChannels;
68     switch (dataLayout)
69     {
70         case DataLayout::NHWC:    //It actually is [ H, W, I, M ]
71             height        = weightShape[0];
72             width         = weightShape[1];
73             inputChannels = weightShape[2];
74             multiplier    = weightShape[3];
75             break;
76         case DataLayout::NCHW:    //It actually is [ M, I, H, W ]
77         default:
78             height        = weightShape[2];
79             width         = weightShape[3];
80             inputChannels = weightShape[1];
81             multiplier    = weightShape[0];
82             break;
83     }
84
85     std::vector<DataType> weightAclOrder(height*width*inputChannels*multiplier);
86     unsigned int destinationWeightsChannel;
87     unsigned int totalChannels = inputChannels * multiplier;
88     unsigned int channelSize   = height * width;
89
90     for (unsigned int originWeightsChannel = 0; originWeightsChannel < totalChannels; originWeightsChannel++)
91     {
92         if (originWeightsChannel % inputChannels == 0)
93         {
94             destinationWeightsChannel = originWeightsChannel / inputChannels;
95         }
96         else
97         {
98             destinationWeightsChannel = (originWeightsChannel - 1) / inputChannels + multiplier;
99         }
100
101         for (unsigned int i = 0; i < channelSize; i++)
102         {
103             weightAclOrder[i + destinationWeightsChannel * channelSize] =
104                     weight[i + originWeightsChannel * channelSize];
105         }
106     }
107
108     ::memcpy(permuteBuffer, weightAclOrder.data(), weightHandle.GetInfo().GetNumBytes());
109     return ConstTensor(weightHandle.GetInfo(), permuteBuffer);
110 }
111
112 TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
113 {
114     // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
115     // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
116
117     // 1. Permute the weights if necessary
118     // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
119     // starting from the current shape of [ M, I, H, W ]
120     TensorInfo weightPermutedInfo(weightInfo);
121     if (dataLayout == DataLayout::NHWC)
122     {
123         // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
124         PermutationVector permutationVector{ 3, 2, 0, 1 };
125         weightPermutedInfo = armnnUtils::Permuted(weightInfo, permutationVector);
126     }
127
128     // 2. Reshape the weights
129     ReshapeWeightsForAcl(weightPermutedInfo, dataLayout);
130
131     // 3. Return the permuted weight info
132     return weightPermutedInfo;
133 }
134
135 armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor,
136                                                      DataLayout dataLayout,
137                                                      void* permuteBuffer)
138 {
139     BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor");
140     BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
141
142     auto multiplier    = weightTensor->GetTensorInfo().GetShape()[0];
143     auto inputChannels = weightTensor->GetTensorInfo().GetShape()[1];
144
145     // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
146     // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
147
148     // 1. Permute the weights if necessary
149     // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done
150     // starting from the current shape of [ M, I, H, W ]
151     // If no permutation is necessary, leave the permutation vector empty
152     PermutationVector permutationVector{};
153     if (dataLayout == DataLayout::NHWC)
154     {
155         // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ]
156         permutationVector = { 3, 2, 0, 1 };
157     }
158     ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
159
160     // Shuffle the weights data to obtain the channel order needed used by Acl
161     if (multiplier > 1 && inputChannels > 1 && dataLayout == DataLayout::NCHW)
162     {
163         switch (weightPermuted.GetDataType())
164         {
165             case DataType::Float32:
166                 weightPermuted = ReorderWeightChannelsForAcl<float>(weightPermuted, dataLayout, permuteBuffer);
167                 break;
168             case DataType::Float16:
169                 weightPermuted =
170                     ReorderWeightChannelsForAcl<half_float::half>(weightPermuted, dataLayout, permuteBuffer);
171                 break;
172             case DataType::QuantisedAsymm8:
173                 weightPermuted = ReorderWeightChannelsForAcl<uint8_t>(weightPermuted, dataLayout, permuteBuffer);
174                 break;
175             default:
176                 break;
177         }
178     }
179
180     // 2. Reshape the weights
181     ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout);
182
183     // 3. Return both the tensor and the allocated storage to ensure that the data stays alive
184     return weightPermuted;
185 }
186
187 } // namespace armnn