2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
#include "NeonConstantWorkload.hpp"

#include <arm_compute/core/Types.h>
#include <BFloat16.hpp>

#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <neon/NeonTensorHandle.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
#include <backendsCommon/Workload.hpp>

#include <boost/cast.hpp>

#include <algorithm>
#include <iterator>
22 arm_compute::Status NeonConstantWorkloadValidate(const TensorInfo& output)
24 const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
26 std::array<arm_compute::DataType,8> supportedTypes = {
27 arm_compute::DataType::BFLOAT16,
28 arm_compute::DataType::F16,
29 arm_compute::DataType::F32,
30 arm_compute::DataType::QASYMM8,
31 arm_compute::DataType::QASYMM8_SIGNED,
32 arm_compute::DataType::QSYMM16,
33 arm_compute::DataType::QSYMM8,
34 arm_compute::DataType::QSYMM8_PER_CHANNEL
36 auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
38 if (it != end(supportedTypes))
40 return arm_compute::Status{};
44 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
48 NeonConstantWorkload::NeonConstantWorkload(const ConstantQueueDescriptor& descriptor,
49 const WorkloadInfo& info)
50 : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
55 void NeonConstantWorkload::Execute() const
57 ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonConstantWorkload_Execute");
59 using namespace armcomputetensorutils;
61 // The intermediate tensor held by the corresponding layer output handler can be initialised with the
62 // given data on the first inference, then reused for subsequent inferences.
63 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
64 // may not have been configured at the time.
67 const ConstantQueueDescriptor& data = this->m_Data;
69 ARMNN_ASSERT(data.m_LayerOutput != nullptr);
70 arm_compute::ITensor& output =
71 PolymorphicDowncast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
72 arm_compute::DataType computeDataType =
73 PolymorphicDowncast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
75 switch (computeDataType)
77 case arm_compute::DataType::BFLOAT16:
79 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<BFloat16>(), output);
82 case arm_compute::DataType::F16:
84 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
87 case arm_compute::DataType::F32:
89 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
92 case arm_compute::DataType::QASYMM8:
94 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
97 case arm_compute::DataType::QASYMM8_SIGNED:
99 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
102 case arm_compute::DataType::QSYMM16:
104 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int16_t>(), output);
107 case arm_compute::DataType::QSYMM8:
108 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
110 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
115 ARMNN_ASSERT_MSG(false, "Unknown data type");