// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
#include "TensorCopyUtils.hpp"

#ifdef ARMCOMPUTECL_ENABLED
#include <cl/ClTensorHandle.hpp>
#endif

#if ARMCOMPUTENEON_ENABLED
#include <neon/NeonTensorHandle.hpp>
#endif

// Fixed guard typo: was ARMCOMPUTECLENABLED (missing underscore), which evaluated
// to 0 and skipped this include in CL-only builds even though the CL copy paths
// below call CopyArmComputeITensorData from it.
#if ARMCOMPUTECL_ENABLED || ARMCOMPUTENEON_ENABLED
#include <aclCommon/ArmComputeTensorUtils.hpp>
#endif

#include <backendsCommon/CpuTensorHandle.hpp>

#include <boost/cast.hpp>

#include <cstring> // memcpy
29 void CopyDataToITensorHandle(armnn::ITensorHandle* tensorHandle, const void* mem)
31 switch (tensorHandle->GetType())
33 case armnn::ITensorHandle::Cpu:
35 auto handle = boost::polymorphic_downcast<armnn::ScopedCpuTensorHandle*>(tensorHandle);
36 memcpy(handle->GetTensor<void>(), mem, handle->GetTensorInfo().GetNumBytes());
39 #ifdef ARMCOMPUTECL_ENABLED
40 case armnn::ITensorHandle::CL:
42 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
43 auto handle = boost::polymorphic_downcast<armnn::IClTensorHandle*>(tensorHandle);
45 switch(handle->GetDataType())
47 case arm_compute::DataType::F32:
48 CopyArmComputeITensorData(static_cast<const float*>(mem), handle->GetTensor());
50 case arm_compute::DataType::QASYMM8:
51 CopyArmComputeITensorData(static_cast<const uint8_t*>(mem), handle->GetTensor());
53 case arm_compute::DataType::F16:
54 CopyArmComputeITensorData(static_cast<const armnn::Half*>(mem), handle->GetTensor());
58 throw armnn::UnimplementedException();
65 #if ARMCOMPUTENEON_ENABLED
66 case armnn::ITensorHandle::Neon:
68 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
69 auto handle = boost::polymorphic_downcast<armnn::INeonTensorHandle*>(tensorHandle);
70 switch (handle->GetDataType())
72 case arm_compute::DataType::F32:
73 CopyArmComputeITensorData(static_cast<const float*>(mem), handle->GetTensor());
75 case arm_compute::DataType::QASYMM8:
76 CopyArmComputeITensorData(static_cast<const uint8_t*>(mem), handle->GetTensor());
80 throw armnn::UnimplementedException();
88 throw armnn::UnimplementedException();
93 void CopyDataFromITensorHandle(void* mem, const armnn::ITensorHandle* tensorHandle)
95 switch (tensorHandle->GetType())
97 case armnn::ITensorHandle::Cpu:
99 auto handle = boost::polymorphic_downcast<const armnn::ScopedCpuTensorHandle*>(tensorHandle);
100 memcpy(mem, handle->GetTensor<void>(), handle->GetTensorInfo().GetNumBytes());
103 #ifdef ARMCOMPUTECL_ENABLED
104 case armnn::ITensorHandle::CL:
106 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
107 auto handle = boost::polymorphic_downcast<const armnn::IClTensorHandle*>(tensorHandle);
108 const_cast<armnn::IClTensorHandle*>(handle)->Map(true);
109 switch(handle->GetDataType())
111 case arm_compute::DataType::F32:
112 CopyArmComputeITensorData(handle->GetTensor(), static_cast<float*>(mem));
114 case arm_compute::DataType::QASYMM8:
115 CopyArmComputeITensorData(handle->GetTensor(), static_cast<uint8_t*>(mem));
117 case arm_compute::DataType::F16:
118 CopyArmComputeITensorData(handle->GetTensor(), static_cast<armnn::Half*>(mem));
122 throw armnn::UnimplementedException();
125 const_cast<armnn::IClTensorHandle*>(handle)->Unmap();
129 #if ARMCOMPUTENEON_ENABLED
130 case armnn::ITensorHandle::Neon:
132 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
133 auto handle = boost::polymorphic_downcast<const armnn::INeonTensorHandle*>(tensorHandle);
134 switch (handle->GetDataType())
136 case arm_compute::DataType::F32:
137 CopyArmComputeITensorData(handle->GetTensor(), static_cast<float*>(mem));
139 case arm_compute::DataType::QASYMM8:
140 CopyArmComputeITensorData(handle->GetTensor(), static_cast<uint8_t*>(mem));
144 throw armnn::UnimplementedException();
152 throw armnn::UnimplementedException();
157 void AllocateAndCopyDataToITensorHandle(armnn::ITensorHandle* tensorHandle, const void* mem)
159 tensorHandle->Allocate();
160 CopyDataToITensorHandle(tensorHandle, mem);