Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / WorkloadUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
#include "armnn/Tensor.hpp"
#include "ITensorHandle.hpp"

#include <boost/cast.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>
12
13 namespace armnn
14 {
// NOTE(review): an unnamed namespace in a header gives every translation unit that includes it
// its own copy of these templates; kept as-is so existing unqualified callers keep resolving.
namespace
{
/// Assigns a single element of @p array to @p arg, counting from the end of the array.
///
/// @p arg receives array[(num - 1) - idx] and @p idx is then incremented, so successive calls
/// walk the array from its last element towards its first. Once @p idx has reached @p num the
/// call does nothing and @p arg keeps whatever value it already held.
///
/// @param num    Number of valid elements in @p array.
/// @param idx    Number of elements already consumed; advanced by one when an assignment happens.
/// @param array  Container indexable with an unsigned int.
/// @param arg    Destination for the selected element.
template<typename ArrayType, typename Arg>
void AssignValues(unsigned int num, unsigned int& idx, const ArrayType& array, Arg& arg)
{
    if (idx >= num)
    {
        return;
    }

    arg = array[(num - 1) - idx];
    idx++;
}

/// Assigns trailing elements of @p array to @p assignee and then to each of @p args, in order.
///
/// @p assignee receives the element @p idx positions from the end (0 = last element), the next
/// argument the element before that, and so on. Arguments for which no element is left are not
/// modified. Example: with array = {7, 8}, num = 2 and idx = 0, AssignValues(2, 0, array, w, h, c, b)
/// sets w = 8 and h = 7 and leaves c and b untouched.
///
/// @param num       Number of valid elements in @p array.
/// @param idx       Starting offset from the end; taken by value, so the caller's variable is not modified.
/// @param array     Container indexable with an unsigned int.
/// @param assignee  First destination.
/// @param args      Remaining destinations, filled towards the front of @p array.
template<typename T, typename ArrayType, typename ...Args>
void AssignValues(unsigned int num, unsigned int idx, const ArrayType& array, T& assignee, Args& ... args)
{
    // With a single destination, overload resolution picks the non-variadic overload above (it is
    // more specialized), which advances this function's local copy of idx before it is passed on.
    AssignValues(num, idx, array, assignee);

    AssignValues(num, idx, array, args...);
}
} // namespace
37
38 template<typename CopyFunc>
39 void CopyTensorContentsGeneric(const ITensorHandle* srcTensor, ITensorHandle* dstTensor, CopyFunc copy)
40 {
41     static_assert(MaxNumOfTensorDimensions == 4, "Please update CopyTensorContents");
42
43     TensorShape srcStrides = srcTensor->GetStrides();
44     const TensorShape& srcShape = srcTensor->GetShape();
45     TensorShape dstStrides = dstTensor->GetStrides();
46     const TensorShape& dstShape = dstTensor->GetShape();
47
48     size_t srcBatches = 1;
49     size_t srcChannels = 1;
50     size_t srcHeight = 1;
51     size_t srcWidth = 1;
52     AssignValues(srcShape.GetNumDimensions(),0, srcShape,
53                  srcWidth,
54                  srcHeight,
55                  srcChannels,
56                  srcBatches);
57
58     size_t srcBatchStride = 0;
59     size_t srcChannelStride = 0;
60     size_t srcHeightStride = 0;
61     size_t srcWidthStride = 0;
62     AssignValues(srcStrides.GetNumDimensions(),0, srcStrides,
63                  srcWidthStride,
64                  srcHeightStride,
65                  srcChannelStride,
66                  srcBatchStride);
67
68     size_t dstBatches = 1;
69     size_t dstChannels = 1;
70     size_t dstHeight = 1;
71     size_t dstWidth = 1;
72     AssignValues(dstShape.GetNumDimensions(),0, dstShape,
73                  dstWidth,
74                  dstHeight,
75                  dstChannels,
76                  dstBatches);
77
78     size_t dstBatchStride = 0;
79     size_t dstChannelStride = 0;
80     size_t dstHeightStride = 0;
81     size_t dstWidthStride = 0;
82     AssignValues(dstStrides.GetNumDimensions(),0, dstStrides,
83                  dstWidthStride,
84                  dstHeightStride,
85                  dstChannelStride,
86                  dstBatchStride);
87
88     auto srcData = static_cast<const uint8_t*>(srcTensor->Map());
89     auto dstData = static_cast<uint8_t*>(dstTensor->Map());
90
91     size_t copyLength = std::min(srcWidth*srcWidthStride, dstWidth*dstWidthStride);
92     size_t copyHeight = std::min(srcHeight, dstHeight);
93     size_t copyChannels = std::min(srcChannels, dstChannels);
94     size_t copyBatches = std::min(srcBatches, dstBatches);
95
96     for(unsigned int b=0; b < copyBatches; ++b)
97     {
98         auto srcPtrBatch = srcData;
99         auto dstPtrBatch = dstData;
100         for (unsigned int c=0; c< copyChannels; ++c)
101         {
102             auto srcPtrChannel = srcData;
103             auto dstPtrChannel = dstData;
104             for (unsigned int h=0; h < copyHeight; ++h)
105             {
106                 copy(dstData, srcData, copyLength);
107                 dstData += dstHeightStride;
108                 srcData += srcHeightStride;
109             }
110             dstData += (static_cast<long>(dstChannelStride) - (dstData - dstPtrChannel));
111             srcData += (static_cast<long>(srcChannelStride) - (srcData - srcPtrChannel));
112         }
113         dstData += (static_cast<long>(dstBatchStride)-(dstData - dstPtrBatch));
114         srcData += (static_cast<long>(srcBatchStride)-(srcData - srcPtrBatch));
115     }
116
117     srcTensor->Unmap();
118     dstTensor->Unmap();
119 }
120
121 template <typename SrcTensorHandleType, typename DstTensorHandleType, typename DescriptorType>
122 void GatherTensorHandlePairs(const DescriptorType& descriptor,
123                              std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
124 {
125     const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
126     tensorHandlePairs.reserve(numInputs);
127
128     for (unsigned int i = 0; i < numInputs; ++i)
129     {
130         SrcTensorHandleType* const srcTensorHandle = boost::polymorphic_downcast<SrcTensorHandleType*>(
131             descriptor.m_Inputs[i]);
132         DstTensorHandleType* const dstTensorHandle = boost::polymorphic_downcast<DstTensorHandleType*>(
133             descriptor.m_Outputs[i]);
134
135         tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
136     }
137 }
138
139 } //namespace armnn