IVGCVSW-2093 Add SpaceToBatchNd layer and corresponding no-op factory implementations
[platform/upstream/armnn.git] / src / backends / MemCopyWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "MemCopyWorkload.hpp"
6 #include "CpuTensorHandle.hpp"
7 #include <TypeUtils.hpp>
8
9 #include <cstring>
10 #include <boost/cast.hpp>
11
12 namespace armnn
13 {
14
15 namespace
16 {
17
18 template <typename SrcTensorHandleType, typename DstTensorHandleType>
19 void GatherTensorHandlePairs(const MemCopyQueueDescriptor& descriptor,
20                              std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
21 {
22     const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
23     tensorHandlePairs.reserve(numInputs);
24
25     for (unsigned int i = 0; i < numInputs; ++i)
26     {
27         SrcTensorHandleType* const srcTensorHandle = boost::polymorphic_downcast<SrcTensorHandleType*>(
28             descriptor.m_Inputs[i]);
29         DstTensorHandleType* const dstTensorHandle = boost::polymorphic_downcast<DstTensorHandleType*>(
30             descriptor.m_Outputs[i]);
31
32         tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
33     }
34 }
35
36 } //namespace
37
38
39 CopyMemGenericWorkload::CopyMemGenericWorkload(const MemCopyQueueDescriptor& descriptor,
40                                                          const WorkloadInfo& info)
41     : BaseWorkload<MemCopyQueueDescriptor>(descriptor, info)
42 {
43     GatherTensorHandlePairs(descriptor, m_TensorHandlePairs);
44 }
45
46 void CopyMemGenericWorkload::Execute() const
47 {
48     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyMemGeneric_Execute");
49
50     auto copyFunc = [](void* dst, const void* src, size_t size)
51         {
52             memcpy(dst, src, size);
53         };
54
55     for (const auto& pair : m_TensorHandlePairs)
56     {
57         CopyTensorContentsGeneric(pair.first, pair.second, copyFunc);
58     }
59 }
60
61 } //namespace armnn