Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloadUtils.hpp
index 549a0bb..6b6a18e 100644 (file)
@@ -9,6 +9,15 @@
 #include <arm_compute/runtime/CL/CLFunctions.h>
 #include <arm_compute/runtime/SubTensor.h>
 #include "ArmComputeTensorUtils.hpp"
+#include "OpenClTimer.hpp"
+#include "CpuTensorHandle.hpp"
+#include "Half.hpp"
+
+#define ARMNN_SCOPED_PROFILING_EVENT_CL(name) \
+    ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS(armnn::Compute::GpuAcc, \
+                                                  name, \
+                                                  armnn::OpenClTimer(), \
+                                                  armnn::WallClockTimer())
 
 namespace armnn
 {
@@ -17,12 +26,12 @@ template <typename T>
 void CopyArmComputeClTensorData(const T* srcData, arm_compute::CLTensor& dstTensor)
 {
     {
-        ARMNN_SCOPED_PROFILING_EVENT(Compute::GpuAcc, "MapClTensorForWriting");
+        ARMNN_SCOPED_PROFILING_EVENT_CL("MapClTensorForWriting");
         dstTensor.map(true);
     }
 
     {
-        ARMNN_SCOPED_PROFILING_EVENT(Compute::GpuAcc, "CopyToClTensor");
+        ARMNN_SCOPED_PROFILING_EVENT_CL("CopyToClTensor");
         armcomputetensorutils::CopyArmComputeITensorData<T>(srcData, dstTensor);
     }
 
@@ -36,4 +45,21 @@ void InitialiseArmComputeClTensorData(arm_compute::CLTensor& clTensor, const T*
     CopyArmComputeClTensorData<T>(data, clTensor);
 }
 
+inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor,
+                                                          const ConstCpuTensorHandle *handle)
+{
+    BOOST_ASSERT(handle);
+    switch(handle->GetTensorInfo().GetDataType())
+    {
+        case DataType::Float16:
+            InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<armnn::Half>());
+            break;
+        case DataType::Float32:
+            InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
+            break;
+        default:
+            BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
+    }
+};
+
 } //namespace armnn