IVGCVSW-1946: Remove armnn/src from the include paths
[platform/upstream/armnn.git] / src / backends / cl / workloads / ClConstantWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClConstantWorkload.hpp"
7
8 #include <Half.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <cl/ClTensorHandle.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
12
13 #include "ClWorkloadUtils.hpp"
14
15 namespace armnn
16 {
17
18 ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
19     : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
20     , m_RanOnce(false)
21 {
22 }
23
24 void ClConstantWorkload::Execute() const
25 {
26     ARMNN_SCOPED_PROFILING_EVENT_CL("ClConstantWorkload_Execute");
27
28     // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
29     // on the first inference, then reused for subsequent inferences.
30     // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
31     // have been configured at the time.
32     if (!m_RanOnce)
33     {
34         const ConstantQueueDescriptor& data = this->m_Data;
35
36         BOOST_ASSERT(data.m_LayerOutput != nullptr);
37         arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
38         arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
39
40         switch (computeDataType)
41         {
42             case arm_compute::DataType::F16:
43             {
44                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
45                 break;
46             }
47             case arm_compute::DataType::F32:
48             {
49                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
50                 break;
51             }
52             case arm_compute::DataType::QASYMM8:
53             {
54                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
55                 break;
56             }
57             default:
58             {
59                 BOOST_ASSERT_MSG(false, "Unknown data type");
60                 break;
61             }
62         }
63
64         m_RanOnce = true;
65     }
66 }
67
68 } //namespace armnn