Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloadUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "Workload.hpp"
8 #include <arm_compute/core/CL/OpenCL.h>
9 #include <arm_compute/runtime/CL/CLFunctions.h>
10 #include <arm_compute/runtime/SubTensor.h>
11 #include "ArmComputeTensorUtils.hpp"
12 #include "OpenClTimer.hpp"
13 #include "CpuTensorHandle.hpp"
14 #include "Half.hpp"
15
// Profiling helper for GpuAcc (OpenCL) workloads: forwards to
// ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS, tagging the event `name` with
// the armnn::Compute::GpuAcc device and attaching two instruments, an
// armnn::OpenClTimer and an armnn::WallClockTimer.
#define ARMNN_SCOPED_PROFILING_EVENT_CL(name)                 \
    ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS(             \
        armnn::Compute::GpuAcc, name,                          \
        armnn::OpenClTimer(), armnn::WallClockTimer())
21
22 namespace armnn
23 {
24
25 template <typename T>
26 void CopyArmComputeClTensorData(const T* srcData, arm_compute::CLTensor& dstTensor)
27 {
28     {
29         ARMNN_SCOPED_PROFILING_EVENT_CL("MapClTensorForWriting");
30         dstTensor.map(true);
31     }
32
33     {
34         ARMNN_SCOPED_PROFILING_EVENT_CL("CopyToClTensor");
35         armcomputetensorutils::CopyArmComputeITensorData<T>(srcData, dstTensor);
36     }
37
38     dstTensor.unmap();
39 }
40
41 template <typename T>
42 void InitialiseArmComputeClTensorData(arm_compute::CLTensor& clTensor, const T* data)
43 {
44     armcomputetensorutils::InitialiseArmComputeTensorEmpty(clTensor);
45     CopyArmComputeClTensorData<T>(data, clTensor);
46 }
47
48 inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor,
49                                                           const ConstCpuTensorHandle *handle)
50 {
51     BOOST_ASSERT(handle);
52     switch(handle->GetTensorInfo().GetDataType())
53     {
54         case DataType::Float16:
55             InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<armnn::Half>());
56             break;
57         case DataType::Float32:
58             InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
59             break;
60         default:
61             BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
62     }
63 };
64
65 } //namespace armnn