Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / NeonWorkloads / NeonBaseConstantWorkload.hpp
index 247ebfc..e0ad408 100644 (file)
@@ -5,23 +5,27 @@
 
 #pragma once
 
+#include <arm_compute/core/Types.h>
 #include <backends/ArmComputeTensorUtils.hpp>
 #include <backends/CpuTensorHandle.hpp>
 #include <backends/NeonTensorHandle.hpp>
+#include <backends/NeonWorkloadUtils.hpp>
 #include <backends/Workload.hpp>
+#include <Half.hpp>
 
 #include <boost/cast.hpp>
+#include "Half.hpp"
 
 namespace armnn
 {
 
-// Base class template providing an implementation of the Constant layer common to all data types
-template <armnn::DataType DataFormat>
-class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormat>
+// Base class template providing an implementation of the Constant layer common to all data types.
+template <armnn::DataType... DataFormats>
+class NeonBaseConstantWorkload : public TypedWorkload<ConstantQueueDescriptor, DataFormats...>
 {
 public:
     NeonBaseConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
-        : TypedWorkload<ConstantQueueDescriptor, DataFormat>(descriptor, info)
+        : TypedWorkload<ConstantQueueDescriptor, DataFormats...>(descriptor, info)
         , m_RanOnce(false)
     {
     }
@@ -41,15 +45,22 @@ public:
             BOOST_ASSERT(data.m_LayerOutput != nullptr);
             arm_compute::ITensor& output =
                 boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
+            arm_compute::DataType computeDataType =
+                boost::polymorphic_downcast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
 
-            switch (DataFormat)
+            switch (computeDataType)
             {
-                case DataType::Float32:
+                case arm_compute::DataType::F16:
+                {
+                    CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
+                    break;
+                }
+                case arm_compute::DataType::F32:
                 {
                     CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
                     break;
                 }
-                case DataType::QuantisedAsymm8:
+                case arm_compute::DataType::QASYMM8:
                 {
                     CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
                     break;