IVGCVSW-1843 : refactor ClAdditionWorkload and ClSubtractionWorkload
authorDavid Beck <david.beck@arm.com>
Tue, 11 Sep 2018 14:21:14 +0000 (15:21 +0100)
committerMatthew Bentham <matthew.bentham@arm.com>
Mon, 1 Oct 2018 13:56:47 +0000 (14:56 +0100)
Change-Id: I0ca9f16217f8e32bb57a49b841611f10dabf021a

22 files changed:
Android.mk
CMakeLists.txt
src/armnn/backends/ClLayerSupport.cpp
src/armnn/backends/ClWorkloadFactory.cpp
src/armnn/backends/ClWorkloads.hpp
src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp [deleted file]
src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp [deleted file]
src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp [deleted file]
src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp [deleted file]
src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp [deleted file]
src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp [deleted file]
src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp [new file with mode: 0644]
src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp [new file with mode: 0644]
src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp [deleted file]
src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp [deleted file]
src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp [deleted file]
src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp [deleted file]
src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp [deleted file]
src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp [deleted file]
src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp [new file with mode: 0644]
src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp [new file with mode: 0644]
src/armnn/backends/test/CreateWorkloadCl.cpp

index 9c4db74d1a5afdeab15fcf6405f54fe87cd3a75a..6f7771c73ceaa6f748cb9812a92eb2bb5a9ed1ce 100644 (file)
@@ -45,12 +45,8 @@ LOCAL_SRC_FILES := \
         src/armnn/backends/ArmComputeTensorUtils.cpp \
         src/armnn/backends/ClWorkloads/ClActivationFloatWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClActivationUint8Workload.cpp \
-        src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp \
-        src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp \
-        src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp \
-        src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp \
-        src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp \
-        src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp \
+        src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp \
+        src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClBatchNormalizationFloatWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClConstantFloatWorkload.cpp \
index 9c2685c96d886f3ca1b7b3b4ab8b2ab67525c0c5..d166a718fcd69082f9cfa80a32ecef34f80bcc73 100644 (file)
@@ -483,18 +483,10 @@ if(ARMCOMPUTECL)
         src/armnn/backends/ClWorkloads/ClActivationFloatWorkload.hpp
         src/armnn/backends/ClWorkloads/ClActivationUint8Workload.cpp
         src/armnn/backends/ClWorkloads/ClActivationUint8Workload.hpp
-        src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp
-        src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp
-        src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp
-        src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp
-        src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp
-        src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp
-        src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp
-        src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp
-        src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp
-        src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp
-        src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp
-        src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp
+        src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp
+        src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp
+        src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp
+        src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp
         src/armnn/backends/ClWorkloads/ClConvertFp16ToFp32Workload.cpp
         src/armnn/backends/ClWorkloads/ClConvertFp16ToFp32Workload.hpp
         src/armnn/backends/ClWorkloads/ClConvertFp32ToFp16Workload.cpp
index 3dba1ec94cdd6488ab8a5dd7872bc4a3264eef0d..aeb2759aa1e450491e0cf4cbf4b1d6ee4e4eaf06 100644 (file)
@@ -14,7 +14,7 @@
 #include <boost/core/ignore_unused.hpp>
 
 #ifdef ARMCOMPUTECL_ENABLED
-#include "ClWorkloads/ClAdditionFloatWorkload.hpp"
+#include "ClWorkloads/ClAdditionWorkload.hpp"
 #include "ClWorkloads/ClActivationFloatWorkload.hpp"
 #include "ClWorkloads/ClBatchNormalizationFloatWorkload.hpp"
 #include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
@@ -29,7 +29,7 @@
 #include "ClWorkloads/ClPermuteWorkload.hpp"
 #include "ClWorkloads/ClNormalizationFloatWorkload.hpp"
 #include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
-#include "ClWorkloads/ClSubtractionFloatWorkload.hpp"
+#include "ClWorkloads/ClSubtractionWorkload.hpp"
 #include "ClWorkloads/ClLstmFloatWorkload.hpp"
 #endif
 
index 056a201783a6b2f8c371c3ec2a4d804d01af92c6..217c63778476386045567ca6ed28483feb3705d8 100644 (file)
@@ -154,7 +154,8 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateNormalization(const N
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
                                                                     const WorkloadInfo&            info) const
 {
-    return MakeWorkload<ClAdditionFloatWorkload, ClAdditionUint8Workload>(descriptor, info);
+    return MakeWorkload<ClAdditionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>,
+                        ClAdditionWorkload<armnn::DataType::QuantisedAsymm8>>(descriptor, info);
 }
 
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMultiplication(
@@ -172,7 +173,8 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
                                                                        const WorkloadInfo& info) const
 {
-    return MakeWorkload<ClSubtractionFloatWorkload, ClSubtractionUint8Workload>(descriptor, info);
+    return MakeWorkload<ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>,
+                        ClSubtractionWorkload<armnn::DataType::QuantisedAsymm8>>(descriptor, info);
 }
 
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateBatchNormalization(
index 0800401a2256b1cf53d1e06b49656e86e9e03a52..3472bca45c644fac7828afca6e0f45240b89fdfe 100644 (file)
@@ -6,8 +6,7 @@
 #pragma once
 #include "backends/ClWorkloads/ClActivationFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClActivationUint8Workload.hpp"
-#include "backends/ClWorkloads/ClAdditionFloatWorkload.hpp"
-#include "backends/ClWorkloads/ClAdditionUint8Workload.hpp"
+#include "backends/ClWorkloads/ClAdditionWorkload.hpp"
 #include "backends/ClWorkloads/ClBaseConstantWorkload.hpp"
 #include "backends/ClWorkloads/ClBaseMergerWorkload.hpp"
 #include "backends/ClWorkloads/ClBatchNormalizationFloatWorkload.hpp"
@@ -36,7 +35,6 @@
 #include "backends/ClWorkloads/ClSoftmaxUint8Workload.hpp"
 #include "backends/ClWorkloads/ClSplitterFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClSplitterUint8Workload.hpp"
-#include "backends/ClWorkloads/ClSubtractionFloatWorkload.hpp"
-#include "backends/ClWorkloads/ClSubtractionUint8Workload.hpp"
+#include "backends/ClWorkloads/ClSubtractionWorkload.hpp"
 #include "backends/ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
 #include "backends/ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp
deleted file mode 100644 (file)
index eb14aa3..0000000
+++ /dev/null
@@ -1,64 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClAdditionBaseWorkload.hpp"
-
-#include "backends/ClTensorHandle.hpp"
-#include "backends/CpuTensorHandle.hpp"
-#include "backends/ArmComputeTensorUtils.hpp"
-
-namespace armnn
-{
-using namespace armcomputetensorutils;
-
-static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE;
-
-template <armnn::DataType... T>
-ClAdditionBaseWorkload<T...>::ClAdditionBaseWorkload(const AdditionQueueDescriptor& descriptor,
-                                                  const WorkloadInfo& info)
-    : TypedWorkload<AdditionQueueDescriptor, T...>(descriptor, info)
-{
-    this->m_Data.ValidateInputsOutputs("ClAdditionBaseWorkload", 2, 1);
-
-    arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
-    arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[1])->GetTensor();
-    arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
-    m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy);
-}
-
-template <armnn::DataType... T>
-void ClAdditionBaseWorkload<T...>::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionBaseWorkload_Execute");
-    m_Layer.run();
-}
-
-bool ClAdditionValidate(const TensorInfo& input0,
-                        const TensorInfo& input1,
-                        const TensorInfo& output,
-                        std::string* reasonIfUnsupported)
-{
-    const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0);
-    const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1);
-    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
-
-    const arm_compute::Status aclStatus = arm_compute::CLArithmeticAddition::validate(&aclInput0Info,
-                                                                                      &aclInput1Info,
-                                                                                      &aclOutputInfo,
-                                                                                      g_AclConvertPolicy);
-
-    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
-    if (!supported && reasonIfUnsupported)
-    {
-        *reasonIfUnsupported = aclStatus.error_description();
-    }
-
-    return supported;
-}
-
-} //namespace armnn
-
-template class armnn::ClAdditionBaseWorkload<armnn::DataType::Float16, armnn::DataType::Float32>;
-template class armnn::ClAdditionBaseWorkload<armnn::DataType::QuantisedAsymm8>;
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp
deleted file mode 100644 (file)
index b3bf1fe..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "backends/ClWorkloadUtils.hpp"
-
-namespace armnn
-{
-
-template <armnn::DataType... dataTypes>
-class ClAdditionBaseWorkload : public TypedWorkload<AdditionQueueDescriptor, dataTypes...>
-{
-public:
-    ClAdditionBaseWorkload(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info);
-
-    void Execute() const override;
-
-private:
-    mutable arm_compute::CLArithmeticAddition m_Layer;
-};
-
-bool ClAdditionValidate(const TensorInfo& input0,
-                        const TensorInfo& input1,
-                        const TensorInfo& output,
-                        std::string* reasonIfUnsupported);
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp
deleted file mode 100644 (file)
index b51d8a7..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClAdditionFloatWorkload.hpp"
-
-#include "backends/ClTensorHandle.hpp"
-#include "backends/CpuTensorHandle.hpp"
-#include "backends/ArmComputeTensorUtils.hpp"
-
-namespace armnn
-{
-using namespace armcomputetensorutils;
-
-void ClAdditionFloatWorkload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionFloatWorkload_Execute");
-    ClAdditionBaseWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp
deleted file mode 100644 (file)
index de33ca6..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ClAdditionBaseWorkload.hpp"
-
-namespace armnn
-{
-
-class ClAdditionFloatWorkload : public ClAdditionBaseWorkload<DataType::Float16, DataType::Float32>
-{
-public:
-    using ClAdditionBaseWorkload<DataType::Float16, DataType::Float32>::ClAdditionBaseWorkload;
-    void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp
deleted file mode 100644 (file)
index 57b9062..0000000
+++ /dev/null
@@ -1,18 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClAdditionUint8Workload.hpp"
-
-namespace armnn
-{
-using namespace armcomputetensorutils;
-
-void ClAdditionUint8Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionUint8Workload_Execute");
-    ClAdditionBaseWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp
deleted file mode 100644 (file)
index d127e7e..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ClAdditionBaseWorkload.hpp"
-
-namespace armnn
-{
-
-class ClAdditionUint8Workload : public ClAdditionBaseWorkload<DataType::QuantisedAsymm8>
-{
-public:
-    using ClAdditionBaseWorkload<DataType::QuantisedAsymm8>::ClAdditionBaseWorkload;
-    void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp
new file mode 100644 (file)
index 0000000..0bba327
--- /dev/null
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClAdditionWorkload.hpp"
+
+#include "backends/ClTensorHandle.hpp"
+#include "backends/CpuTensorHandle.hpp"
+#include "backends/ArmComputeTensorUtils.hpp"
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE;
+
+template <armnn::DataType... T>
+ClAdditionWorkload<T...>::ClAdditionWorkload(const AdditionQueueDescriptor& descriptor,
+                                                  const WorkloadInfo& info)
+    : TypedWorkload<AdditionQueueDescriptor, T...>(descriptor, info)
+{
+    this->m_Data.ValidateInputsOutputs("ClAdditionWorkload", 2, 1);
+
+    arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
+    arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[1])->GetTensor();
+    arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
+    m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy);
+}
+
+template <armnn::DataType... T>
+void ClAdditionWorkload<T...>::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionWorkload_Execute");
+    m_Layer.run();
+}
+
+bool ClAdditionValidate(const TensorInfo& input0,
+                        const TensorInfo& input1,
+                        const TensorInfo& output,
+                        std::string* reasonIfUnsupported)
+{
+    const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0);
+    const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1);
+    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
+
+    const arm_compute::Status aclStatus = arm_compute::CLArithmeticAddition::validate(&aclInput0Info,
+                                                                                      &aclInput1Info,
+                                                                                      &aclOutputInfo,
+                                                                                      g_AclConvertPolicy);
+
+    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
+    if (!supported && reasonIfUnsupported)
+    {
+        *reasonIfUnsupported = aclStatus.error_description();
+    }
+
+    return supported;
+}
+
+} //namespace armnn
+
+template class armnn::ClAdditionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>;
+template class armnn::ClAdditionWorkload<armnn::DataType::QuantisedAsymm8>;
diff --git a/src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp
new file mode 100644 (file)
index 0000000..8af8f23
--- /dev/null
@@ -0,0 +1,29 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/ClWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+template <armnn::DataType... dataTypes>
+class ClAdditionWorkload : public TypedWorkload<AdditionQueueDescriptor, dataTypes...>
+{
+public:
+    ClAdditionWorkload(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+    void Execute() const override;
+
+private:
+    mutable arm_compute::CLArithmeticAddition m_Layer;
+};
+
+bool ClAdditionValidate(const TensorInfo& input0,
+                        const TensorInfo& input1,
+                        const TensorInfo& output,
+                        std::string* reasonIfUnsupported);
+} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp
deleted file mode 100644 (file)
index 2145ed4..0000000
+++ /dev/null
@@ -1,64 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClSubtractionBaseWorkload.hpp"
-
-#include "backends/ClTensorHandle.hpp"
-#include "backends/CpuTensorHandle.hpp"
-#include "backends/ArmComputeTensorUtils.hpp"
-
-namespace armnn
-{
-using namespace armcomputetensorutils;
-
-static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE;
-
-template <armnn::DataType... T>
-ClSubtractionBaseWorkload<T...>::ClSubtractionBaseWorkload(const SubtractionQueueDescriptor& descriptor,
-                                                           const WorkloadInfo& info)
-    : TypedWorkload<SubtractionQueueDescriptor, T...>(descriptor, info)
-{
-    this->m_Data.ValidateInputsOutputs("ClSubtractionBaseWorkload", 2, 1);
-
-    arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
-    arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[1])->GetTensor();
-    arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
-    m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy);
-}
-
-template <armnn::DataType... T>
-void ClSubtractionBaseWorkload<T...>::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionBaseWorkload_Execute");
-    m_Layer.run();
-}
-
-bool ClSubtractionValidate(const TensorInfo& input0,
-                           const TensorInfo& input1,
-                           const TensorInfo& output,
-                           std::string* reasonIfUnsupported)
-{
-    const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0);
-    const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1);
-    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
-
-    const arm_compute::Status aclStatus = arm_compute::CLArithmeticSubtraction::validate(&aclInput0Info,
-                                                                                         &aclInput1Info,
-                                                                                         &aclOutputInfo,
-                                                                                         g_AclConvertPolicy);
-
-    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
-    if (!supported && reasonIfUnsupported)
-    {
-        *reasonIfUnsupported = aclStatus.error_description();
-    }
-
-    return supported;
-}
-
-} //namespace armnn
-
-template class armnn::ClSubtractionBaseWorkload<armnn::DataType::Float16, armnn::DataType::Float32>;
-template class armnn::ClSubtractionBaseWorkload<armnn::DataType::QuantisedAsymm8>;
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp
deleted file mode 100644 (file)
index e4595d4..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "backends/ClWorkloadUtils.hpp"
-
-namespace armnn
-{
-
-template <armnn::DataType... dataTypes>
-class ClSubtractionBaseWorkload : public TypedWorkload<SubtractionQueueDescriptor, dataTypes...>
-{
-public:
-    ClSubtractionBaseWorkload(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info);
-
-    void Execute() const override;
-
-private:
-    mutable arm_compute::CLArithmeticSubtraction m_Layer;
-};
-
-bool ClSubtractionValidate(const TensorInfo& input0,
-                           const TensorInfo& input1,
-                           const TensorInfo& output,
-                           std::string* reasonIfUnsupported);
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp
deleted file mode 100644 (file)
index 3321e20..0000000
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClSubtractionFloatWorkload.hpp"
-
-#include "backends/ClTensorHandle.hpp"
-#include "backends/CpuTensorHandle.hpp"
-#include "backends/ArmComputeTensorUtils.hpp"
-
-namespace armnn
-{
-using namespace armcomputetensorutils;
-
-void ClSubtractionFloatWorkload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionFloatWorkload_Execute");
-    ClSubtractionBaseWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp
deleted file mode 100644 (file)
index 34a5e40..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ClSubtractionBaseWorkload.hpp"
-
-namespace armnn
-{
-
-class ClSubtractionFloatWorkload : public ClSubtractionBaseWorkload<DataType::Float16, DataType::Float32>
-{
-public:
-    using ClSubtractionBaseWorkload<DataType::Float16, DataType::Float32>::ClSubtractionBaseWorkload;
-    void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp
deleted file mode 100644 (file)
index 966068d..0000000
+++ /dev/null
@@ -1,18 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "ClSubtractionUint8Workload.hpp"
-
-namespace armnn
-{
-using namespace armcomputetensorutils;
-
-void ClSubtractionUint8Workload::Execute() const
-{
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionUint8Workload_Execute");
-    ClSubtractionBaseWorkload::Execute();
-}
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp
deleted file mode 100644 (file)
index 15b2059..0000000
+++ /dev/null
@@ -1,20 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ClSubtractionBaseWorkload.hpp"
-
-namespace armnn
-{
-
-class ClSubtractionUint8Workload : public ClSubtractionBaseWorkload<DataType::QuantisedAsymm8>
-{
-public:
-    using ClSubtractionBaseWorkload<DataType::QuantisedAsymm8>::ClSubtractionBaseWorkload;
-    void Execute() const override;
-};
-
-} //namespace armnn
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp
new file mode 100644 (file)
index 0000000..ec8bfc6
--- /dev/null
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClSubtractionWorkload.hpp"
+
+#include "backends/ClTensorHandle.hpp"
+#include "backends/CpuTensorHandle.hpp"
+#include "backends/ArmComputeTensorUtils.hpp"
+
+namespace armnn
+{
+using namespace armcomputetensorutils;
+
+static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE;
+
+template <armnn::DataType... T>
+ClSubtractionWorkload<T...>::ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor,
+                                                           const WorkloadInfo& info)
+    : TypedWorkload<SubtractionQueueDescriptor, T...>(descriptor, info)
+{
+    this->m_Data.ValidateInputsOutputs("ClSubtractionWorkload", 2, 1);
+
+    arm_compute::ICLTensor& input0 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
+    arm_compute::ICLTensor& input1 = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[1])->GetTensor();
+    arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
+    m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy);
+}
+
+template <armnn::DataType... T>
+void ClSubtractionWorkload<T...>::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionWorkload_Execute");
+    m_Layer.run();
+}
+
+bool ClSubtractionValidate(const TensorInfo& input0,
+                           const TensorInfo& input1,
+                           const TensorInfo& output,
+                           std::string* reasonIfUnsupported)
+{
+    const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0);
+    const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1);
+    const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
+
+    const arm_compute::Status aclStatus = arm_compute::CLArithmeticSubtraction::validate(&aclInput0Info,
+                                                                                         &aclInput1Info,
+                                                                                         &aclOutputInfo,
+                                                                                         g_AclConvertPolicy);
+
+    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
+    if (!supported && reasonIfUnsupported)
+    {
+        *reasonIfUnsupported = aclStatus.error_description();
+    }
+
+    return supported;
+}
+
+} //namespace armnn
+
+template class armnn::ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>;
+template class armnn::ClSubtractionWorkload<armnn::DataType::QuantisedAsymm8>;
diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp
new file mode 100644 (file)
index 0000000..422e6a7
--- /dev/null
@@ -0,0 +1,29 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "backends/ClWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+template <armnn::DataType... dataTypes>
+class ClSubtractionWorkload : public TypedWorkload<SubtractionQueueDescriptor, dataTypes...>
+{
+public:
+    ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info);
+
+    void Execute() const override;
+
+private:
+    mutable arm_compute::CLArithmeticSubtraction m_Layer;
+};
+
+bool ClSubtractionValidate(const TensorInfo& input0,
+                           const TensorInfo& input1,
+                           const TensorInfo& output,
+                           std::string* reasonIfUnsupported);
+} //namespace armnn
index 340279e61915bf10085a51a8de5fed8fce06c272..23843bd09567becbb92a6f320cdae7f2e2e7831d 100644 (file)
@@ -69,7 +69,7 @@ static void ClCreateArithmethicWorkloadTest()
 
 BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
 {
-    ClCreateArithmethicWorkloadTest<ClAdditionFloatWorkload,
+    ClCreateArithmethicWorkloadTest<ClAdditionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>,
                                     AdditionQueueDescriptor,
                                     AdditionLayer,
                                     armnn::DataType::Float32>();
@@ -77,7 +77,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload)
 
 BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload)
 {
-    ClCreateArithmethicWorkloadTest<ClAdditionFloatWorkload,
+    ClCreateArithmethicWorkloadTest<ClAdditionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>,
                                     AdditionQueueDescriptor,
                                     AdditionLayer,
                                     armnn::DataType::Float16>();
@@ -85,7 +85,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload)
 
 BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
 {
-    ClCreateArithmethicWorkloadTest<ClSubtractionFloatWorkload,
+    ClCreateArithmethicWorkloadTest<ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>,
                                     SubtractionQueueDescriptor,
                                     SubtractionLayer,
                                     armnn::DataType::Float32>();
@@ -93,7 +93,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload)
 
 BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload)
 {
-    ClCreateArithmethicWorkloadTest<ClSubtractionFloatWorkload,
+    ClCreateArithmethicWorkloadTest<ClSubtractionWorkload<armnn::DataType::Float16, armnn::DataType::Float32>,
                                     SubtractionQueueDescriptor,
                                     SubtractionLayer,
                                     armnn::DataType::Float16>();