#pragma once
+#include <arm_compute/core/Types.h>
#include <backends/ArmComputeTensorUtils.hpp>
#include <backends/CpuTensorHandle.hpp>
#include <backends/NeonTensorHandle.hpp>
+#include <backends/NeonWorkloadUtils.hpp>
#include <backends/Workload.hpp>
+#include <Half.hpp>
#include <boost/cast.hpp>
+#include "Half.hpp"
namespace armnn
{
-// Base class template providing an implementation of the Constant layer common to all data types
-template <armnn::DataType DataFormat>
-class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormat>
+// Base class template providing an implementation of the Constant layer common to all data types.
+template <armnn::DataType... DataFormats>
+class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormats...>
{
public:
// Constructs the workload: forwards the queue descriptor and workload info to the
// TypedWorkload base class and initializes m_RanOnce to false.
NeonBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
- : TypedWorkload<ConstantQueueDescriptor, DataFormat>(descriptor, info)
+ : TypedWorkload<ConstantQueueDescriptor, DataFormats...>(descriptor, info)
, m_RanOnce(false)
{
}
BOOST_ASSERT(data.m_LayerOutput != nullptr);
arm_compute::ITensor& output =
boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
+ arm_compute::DataType computeDataType =
+ boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
- switch (DataFormat)
+ switch (computeDataType)
{
- case DataType::Float32:
+ case arm_compute::DataType::F16:
+ {
+ CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
+ break;
+ }
+ case arm_compute::DataType::F32:
{
CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
break;
}
- case DataType::QuantisedAsymm8:
+ case arm_compute::DataType::QASYMM8:
{
CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
break;