Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / NeonWorkloads / NeonBaseConstantWorkload.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include <arm_compute/core/Types.h>
9 #include <backends/ArmComputeTensorUtils.hpp>
10 #include <backends/CpuTensorHandle.hpp>
11 #include <backends/NeonTensorHandle.hpp>
12 #include <backends/NeonWorkloadUtils.hpp>
13 #include <backends/Workload.hpp>
14 #include <Half.hpp>
15
16 #include <boost/cast.hpp>
17 #include "Half.hpp"
18
19 namespace armnn
20 {
21
22 // Base class template providing an implementation of the Constant layer common to all data types.
23 template <armnn::DataType... DataFormats>
24 class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormats...>
25 {
26 public:
27     NeonBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
28         : TypedWorkload<ConstantQueueDescriptor, DataFormats...>(descriptor, info)
29         , m_RanOnce(false)
30     {
31     }
32
33     virtual void Execute() const override
34     {
35         using namespace armcomputetensorutils;
36
37         // The intermediate tensor held by the corresponding layer output handler can be initialised with the
38         // given data on the first inference, then reused for subsequent inferences.
39         // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
40         // may not have been configured at the time.
41         if (!m_RanOnce)
42         {
43             const ConstantQueueDescriptor& data = this->m_Data;
44
45             BOOST_ASSERT(data.m_LayerOutput != nullptr);
46             arm_compute::ITensor& output =
47                 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
48             arm_compute::DataType computeDataType =
49                 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
50
51             switch (computeDataType)
52             {
53                 case arm_compute::DataType::F16:
54                 {
55                     CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
56                     break;
57                 }
58                 case arm_compute::DataType::F32:
59                 {
60                     CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
61                     break;
62                 }
63                 case arm_compute::DataType::QASYMM8:
64                 {
65                     CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
66                     break;
67                 }
68                 default:
69                 {
70                     BOOST_ASSERT_MSG(false, "Unknown data type");
71                     break;
72                 }
73             }
74
75             m_RanOnce = true;
76         }
77     }
78
79 private:
80     mutable bool m_RanOnce;
81 };
82
83 } //namespace armnn