IVGCVSW-1946: Remove armnn/src from the include paths
platform/upstream/armnn.git: src/backends/neon/workloads/NeonConstantWorkload.cpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NeonConstantWorkload.hpp"

#include <arm_compute/core/Types.h>
#include <Half.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <neon/NeonTensorHandle.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/Workload.hpp>

#include <boost/cast.hpp>

namespace armnn
{

NeonConstantWorkload::NeonConstantWorkload(const ConstantQueueDescriptor& descriptor,
                                           const WorkloadInfo& info)
    : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
    , m_RanOnce(false)
{
}

void NeonConstantWorkload::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConstantWorkload_Execute");

    using namespace armcomputetensorutils;

    // The intermediate tensor held by the corresponding layer output handler can be initialised with the
    // given data on the first inference, then reused for subsequent inferences.
    // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
    // may not have been configured at the time.
    if (!m_RanOnce)
    {
        const ConstantQueueDescriptor& data = this->m_Data;

        BOOST_ASSERT(data.m_LayerOutput != nullptr);
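        // Downcast the generic output tensor handle to a NeonTensorHandle to reach the backing
        // ACL tensor and its data type.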
        arm_compute::ITensor& output =
            boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
        arm_compute::DataType computeDataType =
            boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();

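        // Copy the layer's constant data into the ACL tensor, choosing the source element type
        // to match the compute data type.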
        switch (computeDataType)
        {
            case arm_compute::DataType::F16:
            {
                CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
                break;
            }
            case arm_compute::DataType::F32:
            {
                CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
                break;
            }
            case arm_compute::DataType::QASYMM8:
            {
                CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
                break;
            }
            default:
            {
                BOOST_ASSERT_MSG(false, "Unknown data type");
                break;
            }
        }

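        // The copy only needs to happen once; later executions reuse the initialised tensor.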
        m_RanOnce = true;
    }
}

} //namespace armnn