MLCE-190: Neon and CL Constant Workloads do not support newer DataTypes
[platform/upstream/armnn.git] / src / backends / neon / workloads / NeonConstantWorkload.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "NeonConstantWorkload.hpp"
7
8 #include <arm_compute/core/Types.h>
9 #include <BFloat16.hpp>
10 #include <Half.hpp>
11 #include <aclCommon/ArmComputeTensorUtils.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 #include <neon/NeonTensorHandle.hpp>
14 #include <backendsCommon/CpuTensorHandle.hpp>
15 #include <backendsCommon/Workload.hpp>
16
17 #include <boost/cast.hpp>
18
19 namespace armnn
20 {
21
22 arm_compute::Status NeonConstantWorkloadValidate(const TensorInfo& output)
23 {
24     const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
25
26     std::array<arm_compute::DataType,8> supportedTypes = {
27             arm_compute::DataType::BFLOAT16,
28             arm_compute::DataType::F16,
29             arm_compute::DataType::F32,
30             arm_compute::DataType::QASYMM8,
31             arm_compute::DataType::QASYMM8_SIGNED,
32             arm_compute::DataType::QSYMM16,
33             arm_compute::DataType::QSYMM8,
34             arm_compute::DataType::QSYMM8_PER_CHANNEL
35     };
36     auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
37
38     if (it != end(supportedTypes))
39     {
40         return arm_compute::Status{};
41     }
42     else
43     {
44         return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
45     }
46 }
47
/// Constructs the workload. The constant data itself is NOT copied here;
/// m_RanOnce starts false so the first Execute() performs the one-off copy
/// into the layer's output tensor (see comment in Execute for why the copy
/// must be deferred).
NeonConstantWorkload::NeonConstantWorkload(const ConstantQueueDescriptor& descriptor,
                                           const WorkloadInfo& info)
    : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
    , m_RanOnce(false)
{
}
54
55 void NeonConstantWorkload::Execute() const
56 {
57     ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConstantWorkload_Execute");
58
59     using namespace armcomputetensorutils;
60
61     // The intermediate tensor held by the corresponding layer output handler can be initialised with the
62     // given data on the first inference, then reused for subsequent inferences.
63     // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
64     // may not have been configured at the time.
65     if (!m_RanOnce)
66     {
67         const ConstantQueueDescriptor& data = this->m_Data;
68
69         ARMNN_ASSERT(data.m_LayerOutput != nullptr);
70         arm_compute::ITensor& output =
71             PolymorphicDowncast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
72         arm_compute::DataType computeDataType =
73             PolymorphicDowncast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
74
75         switch (computeDataType)
76         {
77             case arm_compute::DataType::BFLOAT16:
78             {
79                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<BFloat16>(), output);
80                 break;
81             }
82             case arm_compute::DataType::F16:
83             {
84                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
85                 break;
86             }
87             case arm_compute::DataType::F32:
88             {
89                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
90                 break;
91             }
92             case arm_compute::DataType::QASYMM8:
93             {
94                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
95                 break;
96             }
97             case arm_compute::DataType::QASYMM8_SIGNED:
98             {
99                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
100                 break;
101             }
102             case arm_compute::DataType::QSYMM16:
103             {
104                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int16_t>(), output);
105                 break;
106             }
107             case arm_compute::DataType::QSYMM8:
108             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
109             {
110                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
111                 break;
112             }
113             default:
114             {
115                 ARMNN_ASSERT_MSG(false, "Unknown data type");
116                 break;
117             }
118         }
119
120         m_RanOnce = true;
121     }
122 }
123
124 } //namespace armnn