2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "ClConstantWorkload.hpp"
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <cl/ClTensorHandle.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
13 #include "ClWorkloadUtils.hpp"
// Constructs the OpenCL constant workload by forwarding the queue descriptor (which, per Execute(),
// carries the constant data in m_LayerOutput and the output handle in m_Outputs) and the workload
// info to BaseWorkload<ConstantQueueDescriptor>.
// NOTE(review): this view shows only the base-class initializer; any additional member initialisation
// and the constructor body are not visible here — confirm against the full file.
18 ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
19 : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
// Copies the constant data held by the descriptor (m_Data.m_LayerOutput) into the CL tensor backing
// output 0. The host element type used to read the data (Half, float or uint8_t) is chosen from that
// output handle's ACL data type (F16, F32 or QASYMM8 respectively).
// Assertions: m_LayerOutput must be non-null; any other data type triggers "Unknown data type".
// NOTE(review): the explanatory comment below says the copy should happen only on the first inference
// and be reused afterwards; the guard enforcing that is not visible in this view — confirm it exists,
// otherwise the copy would be repeated on every Execute() call.
24 void ClConstantWorkload::Execute() const
26 ARMNN_SCOPED_PROFILING_EVENT_CL("ClConstantWorkload_Execute");
28 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
29 // on the first inference, then reused for subsequent inferences.
30 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
31 // have been configured at the time.
34 const ConstantQueueDescriptor& data = this->m_Data;
36 BOOST_ASSERT(data.m_LayerOutput != nullptr);
// Both lookups use an unchecked static_cast and an unchecked index 0: output 0 is assumed to exist and
// to be a ClTensorHandle. NOTE(review): confirm workload validation guarantees this before Execute().
37 arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
38 arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
// Read the constant data as the host type matching the output's ACL data type, then copy it into the
// CL tensor so the element width on both sides agrees.
40 switch (computeDataType)
42 case arm_compute::DataType::F16:
44 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
47 case arm_compute::DataType::F32:
49 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
52 case arm_compute::DataType::QASYMM8:
54 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
// NOTE(review): this assert presumably sits in the switch's default branch (its label is not visible in
// this view). If BOOST_ASSERT is compiled out (e.g. under NDEBUG), an unsupported data type would not be
// reported here — confirm that this is acceptable.
59 BOOST_ASSERT_MSG(false, "Unknown data type");