IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonWorkloadUtils.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <neon/NeonTensorHandle.hpp>
#include <neon/NeonTimer.hpp>

#include <armnn/Exceptions.hpp>

#include <arm_compute/runtime/NEON/NEFunctions.h>

#include <Half.hpp>
15
16 #define ARMNN_SCOPED_PROFILING_EVENT_NEON(name) \
17     ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS(armnn::Compute::CpuAcc, \
18                                                   name, \
19                                                   armnn::NeonTimer(), \
20                                                   armnn::WallClockTimer())
21
22 using namespace armnn::armcomputetensorutils;
23
24 namespace armnn
25 {
26
27 template <typename T>
28 void CopyArmComputeTensorData(arm_compute::Tensor& dstTensor, const T* srcData)
29 {
30     InitialiseArmComputeTensorEmpty(dstTensor);
31     CopyArmComputeITensorData(srcData, dstTensor);
32 }
33
34 inline void InitializeArmComputeTensorData(arm_compute::Tensor& tensor,
35                                            const ConstCpuTensorHandle* handle)
36 {
37     BOOST_ASSERT(handle);
38
39     switch(handle->GetTensorInfo().GetDataType())
40     {
41         case DataType::Float16:
42             CopyArmComputeTensorData(tensor, handle->GetConstTensor<armnn::Half>());
43             break;
44         case DataType::Float32:
45             CopyArmComputeTensorData(tensor, handle->GetConstTensor<float>());
46             break;
47         case DataType::QuantisedAsymm8:
48             CopyArmComputeTensorData(tensor, handle->GetConstTensor<uint8_t>());
49             break;
50         case DataType::Signed32:
51             CopyArmComputeTensorData(tensor, handle->GetConstTensor<int32_t>());
52             break;
53         default:
54             BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
55     }
56 };
57
58 } //namespace armnn