2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
#include <cstring>

#include <boost/cast.hpp>

#include "TensorCopyUtils.hpp"

#ifdef ARMCOMPUTECL_ENABLED
#include "backends/ClTensorHandle.hpp"
#endif

#if ARMCOMPUTENEON_ENABLED
#include "backends/NeonTensorHandle.hpp"
#endif

// Must test the CL macro with the same spelling/semantics as the #ifdef above;
// the previous spelling (ARMCOMPUTECLENABLED) was never defined.
#if defined(ARMCOMPUTECL_ENABLED) || ARMCOMPUTENEON_ENABLED
#include "backends/ArmComputeTensorUtils.hpp"
#endif

#include "backends/CpuTensorHandle.hpp"
27 void CopyDataToITensorHandle(armnn::ITensorHandle* tensorHandle, const void* mem)
29 switch (tensorHandle->GetType())
31 case armnn::ITensorHandle::Cpu:
33 auto handle = boost::polymorphic_downcast<armnn::ScopedCpuTensorHandle*>(tensorHandle);
34 memcpy(handle->GetTensor<void>(), mem, handle->GetTensorInfo().GetNumBytes());
37 #ifdef ARMCOMPUTECL_ENABLED
38 case armnn::ITensorHandle::CL:
40 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
41 auto handle = boost::polymorphic_downcast<armnn::IClTensorHandle*>(tensorHandle);
43 switch(handle->GetDataType())
45 case arm_compute::DataType::F32:
46 CopyArmComputeITensorData(static_cast<const float*>(mem), handle->GetTensor());
48 case arm_compute::DataType::QASYMM8:
49 CopyArmComputeITensorData(static_cast<const uint8_t*>(mem), handle->GetTensor());
51 case arm_compute::DataType::F16:
52 CopyArmComputeITensorData(static_cast<const armnn::Half*>(mem), handle->GetTensor());
56 throw armnn::UnimplementedException();
63 #if ARMCOMPUTENEON_ENABLED
64 case armnn::ITensorHandle::Neon:
66 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
67 auto handle = boost::polymorphic_downcast<armnn::INeonTensorHandle*>(tensorHandle);
68 switch (handle->GetDataType())
70 case arm_compute::DataType::F32:
71 CopyArmComputeITensorData(static_cast<const float*>(mem), handle->GetTensor());
73 case arm_compute::DataType::QASYMM8:
74 CopyArmComputeITensorData(static_cast<const uint8_t*>(mem), handle->GetTensor());
78 throw armnn::UnimplementedException();
86 throw armnn::UnimplementedException();
91 void CopyDataFromITensorHandle(void* mem, const armnn::ITensorHandle* tensorHandle)
93 switch (tensorHandle->GetType())
95 case armnn::ITensorHandle::Cpu:
97 auto handle = boost::polymorphic_downcast<const armnn::ScopedCpuTensorHandle*>(tensorHandle);
98 memcpy(mem, handle->GetTensor<void>(), handle->GetTensorInfo().GetNumBytes());
101 #ifdef ARMCOMPUTECL_ENABLED
102 case armnn::ITensorHandle::CL:
104 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
105 auto handle = boost::polymorphic_downcast<const armnn::IClTensorHandle*>(tensorHandle);
106 const_cast<armnn::IClTensorHandle*>(handle)->Map(true);
107 switch(handle->GetDataType())
109 case arm_compute::DataType::F32:
110 CopyArmComputeITensorData(handle->GetTensor(), static_cast<float*>(mem));
112 case arm_compute::DataType::QASYMM8:
113 CopyArmComputeITensorData(handle->GetTensor(), static_cast<uint8_t*>(mem));
115 case arm_compute::DataType::F16:
116 CopyArmComputeITensorData(handle->GetTensor(), static_cast<armnn::Half*>(mem));
120 throw armnn::UnimplementedException();
123 const_cast<armnn::IClTensorHandle*>(handle)->Unmap();
127 #if ARMCOMPUTENEON_ENABLED
128 case armnn::ITensorHandle::Neon:
130 using armnn::armcomputetensorutils::CopyArmComputeITensorData;
131 auto handle = boost::polymorphic_downcast<const armnn::INeonTensorHandle*>(tensorHandle);
132 switch (handle->GetDataType())
134 case arm_compute::DataType::F32:
135 CopyArmComputeITensorData(handle->GetTensor(), static_cast<float*>(mem));
137 case arm_compute::DataType::QASYMM8:
138 CopyArmComputeITensorData(handle->GetTensor(), static_cast<uint8_t*>(mem));
142 throw armnn::UnimplementedException();
150 throw armnn::UnimplementedException();
155 void AllocateAndCopyDataToITensorHandle(armnn::ITensorHandle* tensorHandle, const void* mem)
157 tensorHandle->Allocate();
158 CopyDataToITensorHandle(tensorHandle, mem);