IVGCVSW-2071 : remove GetCompute() from the WorkloadFactory interface
author    David Beck <david.beck@arm.com>
          Tue, 23 Oct 2018 15:09:36 +0000 (16:09 +0100)
committer Matthew Bentham <matthew.bentham@arm.com>
          Thu, 25 Oct 2018 08:49:58 +0000 (09:49 +0100)
Change-Id: I44a9d26d1a5d876d381aee4c6450af62811d0dbb

src/armnn/test/CreateWorkload.hpp
src/backends/WorkloadFactory.hpp
src/backends/cl/ClWorkloadFactory.cpp
src/backends/cl/ClWorkloadFactory.hpp
src/backends/neon/NeonWorkloadFactory.cpp
src/backends/neon/NeonWorkloadFactory.hpp
src/backends/reference/RefWorkloadFactory.cpp
src/backends/reference/RefWorkloadFactory.hpp
src/backends/test/NormTestImpl.hpp
src/backends/test/Pooling2dTestImpl.hpp

diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index aac0a4a..5308a1c 100644
@@ -32,7 +32,7 @@ std::unique_ptr<Workload> MakeAndCheckWorkload(Layer& layer, Graph& graph, const
     BOOST_TEST(workload.get() == boost::polymorphic_downcast<Workload*>(workload.get()),
                "Cannot convert to derived class");
     std::string reasonIfUnsupported;
-    layer.SetBackendId(factory.GetCompute());
+    layer.SetBackendId(factory.GetBackendId());
     BOOST_TEST(factory.IsLayerSupported(layer, layer.GetDataType(), reasonIfUnsupported));
     return std::unique_ptr<Workload>(static_cast<Workload*>(workload.release()));
 }
diff --git a/src/backends/WorkloadFactory.hpp b/src/backends/WorkloadFactory.hpp
index 2d482e0..2f422ab 100644
@@ -21,7 +21,7 @@ class IWorkloadFactory
 public:
     virtual ~IWorkloadFactory() { }
 
-    virtual Compute GetCompute() const = 0;
+    virtual const BackendId& GetBackendId() const = 0;
 
     /// Informs the memory manager that the network is finalized and ready for execution.
     virtual void Finalize() { }
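
The hunk above is the heart of the change: the fixed Compute enum is replaced by the string-backed BackendId, so a factory identifies its backend by name and adding a backend no longer means extending an enum. The remaining diffs apply the same pattern to each factory. Below is a minimal, self-contained sketch of that pattern; BackendId and IWorkloadFactory here are stripped-down stand-ins for the real Arm NN types, and SampleWorkloadFactory is hypothetical.

    #include <iostream>
    #include <string>

    // Simplified stand-in for armnn::BackendId: an immutable, string-backed name.
    class BackendId
    {
    public:
        BackendId(const char* id) : m_Id(id) {}
        const std::string& Get() const { return m_Id; }
        bool operator==(const BackendId& other) const { return m_Id == other.m_Id; }
    private:
        std::string m_Id;
    };

    // Stripped-down analogue of IWorkloadFactory after this change.
    class IWorkloadFactory
    {
    public:
        virtual ~IWorkloadFactory() {}
        virtual const BackendId& GetBackendId() const = 0;
    };

    namespace
    {
    // One file-local constant per factory; returning it by const reference
    // avoids constructing a BackendId (and its string) on every call.
    static const BackendId s_Id{"SampleAcc"};
    }

    class SampleWorkloadFactory : public IWorkloadFactory
    {
    public:
        const BackendId& GetBackendId() const override { return s_Id; }
    };

    int main()
    {
        SampleWorkloadFactory factory;
        std::cout << factory.GetBackendId().Get() << std::endl; // prints SampleAcc
        return 0;
    }

Returning const BackendId& rather than BackendId by value matches what the factories below do with their file-local s_Id constants.
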
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index c697d90..fd92db3 100644
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include "ClWorkloadFactory.hpp"
+#include "ClBackendId.hpp"
 
 #include <armnn/Exceptions.hpp>
 #include <armnn/Utils.hpp>
 namespace armnn
 {
 
+namespace
+{
+static const BackendId s_Id{ClBackendId()};
+}
+
 bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
                                          Optional<DataType> dataType,
                                          std::string& outReasonIfUnsupported)
 {
-    return IWorkloadFactory::IsLayerSupported(Compute::GpuAcc, layer, dataType, outReasonIfUnsupported);
+    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+const BackendId& ClWorkloadFactory::GetBackendId() const
+{
+    return s_Id;
 }
 
 #ifdef ARMCOMPUTECL_ENABLED
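
ClBackendId.hpp itself is not part of this diff; given that s_Id is built from ClBackendId() and the removed code returned Compute::GpuAcc, it plausibly just exposes the backend's name as a constexpr string, along these lines (assumed, including the exact "GpuAcc" value):

    // Assumed shape of ClBackendId.hpp (not shown in this commit).
    namespace armnn
    {
    constexpr const char* ClBackendId() { return "GpuAcc"; }
    }

The Neon and reference diffs below follow the same scheme, presumably with NeonBackendId() and RefBackendId() naming "CpuAcc" and "CpuRef" respectively.
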
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 1441b71..ba7cf69 100644
@@ -19,7 +19,7 @@ class ClWorkloadFactory : public IWorkloadFactory
 public:
     ClWorkloadFactory();
 
-    virtual Compute GetCompute() const override { return Compute::GpuAcc; }
+    const BackendId& GetBackendId() const override;
 
     static bool IsLayerSupported(const Layer& layer,
                                  Optional<DataType> dataType,
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index f0a9e76..c16d383 100644
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 #include "NeonWorkloadFactory.hpp"
+#include "NeonBackendId.hpp"
 #include <armnn/Utils.hpp>
 #include <backends/CpuTensorHandle.hpp>
 #include <Layer.hpp>
 namespace armnn
 {
 
+namespace
+{
+static const BackendId s_Id{NeonBackendId()};
+}
+
 bool NeonWorkloadFactory::IsLayerSupported(const Layer& layer,
                                            Optional<DataType> dataType,
                                            std::string& outReasonIfUnsupported)
 {
-    return IWorkloadFactory::IsLayerSupported(Compute::CpuAcc, layer, dataType, outReasonIfUnsupported);
+    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+const BackendId& NeonWorkloadFactory::GetBackendId() const
+{
+    return s_Id;
 }
 
 #ifdef ARMCOMPUTENEON_ENABLED
diff --git a/src/backends/neon/NeonWorkloadFactory.hpp b/src/backends/neon/NeonWorkloadFactory.hpp
index d1dd2c8..030e982 100644
@@ -20,7 +20,7 @@ class NeonWorkloadFactory : public IWorkloadFactory
 public:
     NeonWorkloadFactory();
 
-    virtual Compute GetCompute() const override { return Compute::CpuAcc; }
+    const BackendId& GetBackendId() const override;
 
     static bool IsLayerSupported(const Layer& layer,
                                  Optional<DataType> dataType,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 783e5fb..864ffdb 100644
@@ -6,6 +6,7 @@
 #include <backends/MemCopyWorkload.hpp>
 #include <backends/MakeWorkloadHelper.hpp>
 #include "RefWorkloadFactory.hpp"
+#include "RefBackendId.hpp"
 #include "workloads/RefWorkloads.hpp"
 #include "Layer.hpp"
 
 namespace armnn
 {
 
+namespace
+{
+static const BackendId s_Id{RefBackendId()};
+}
+
 template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
 std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
     const WorkloadInfo& info) const
@@ -25,11 +31,16 @@ RefWorkloadFactory::RefWorkloadFactory()
 {
 }
 
+const BackendId& RefWorkloadFactory::GetBackendId() const
+{
+    return s_Id;
+}
+
 bool RefWorkloadFactory::IsLayerSupported(const Layer& layer,
                                           Optional<DataType> dataType,
                                           std::string& outReasonIfUnsupported)
 {
-    return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
+    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
 }
 
 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index ef2e1ab..be0dafc 100644
@@ -33,7 +33,7 @@ public:
     explicit RefWorkloadFactory();
     virtual ~RefWorkloadFactory() {}
 
-    virtual Compute GetCompute() const override { return Compute::CpuRef; }
+    const BackendId& GetBackendId() const override;
 
     static bool IsLayerSupported(const Layer& layer,
                                  Optional<DataType> dataType,
diff --git a/src/backends/test/NormTestImpl.hpp b/src/backends/test/NormTestImpl.hpp
index f4e6aea..de954b9 100644
@@ -308,10 +308,10 @@ LayerTestResult<float,4> CompareNormalizationTestImpl(armnn::IWorkloadFactory& w
     SetWorkloadOutput(refData, refInfo, 0, outputTensorInfo, outputHandleRef.get());
 
     // Don't execute if Normalization is not supported for the method and channel types, as an exception will be raised.
-    armnn::Compute compute = workloadFactory.GetCompute();
+    armnn::BackendId backend = workloadFactory.GetBackendId();
     const size_t reasonIfUnsupportedMaxLen = 255;
     char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
-    ret.supported = armnn::IsNormalizationSupported(compute, inputTensorInfo, outputTensorInfo, data.m_Parameters,
+    ret.supported = armnn::IsNormalizationSupported(backend, inputTensorInfo, outputTensorInfo, data.m_Parameters,
                                                     reasonIfUnsupported, reasonIfUnsupportedMaxLen);
     if (!ret.supported)
     {
diff --git a/src/backends/test/Pooling2dTestImpl.hpp b/src/backends/test/Pooling2dTestImpl.hpp
index 29263af..90be289 100644
@@ -77,10 +77,10 @@ LayerTestResult<T, 4> SimplePooling2dTestImpl(armnn::IWorkloadFactory& workloadF
     AddOutputToWorkload(queueDescriptor, workloadInfo, outputTensorInfo, outputHandle.get());
 
     // Don't execute if Pooling is not supported, as an exception will be raised.
-    armnn::Compute compute = workloadFactory.GetCompute();
+    armnn::BackendId backend = workloadFactory.GetBackendId();
     const size_t reasonIfUnsupportedMaxLen = 255;
     char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
-    result.supported = armnn::IsPooling2dSupported(compute, inputTensorInfo, outputTensorInfo,
+    result.supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
                                                    queueDescriptor.m_Parameters,
                                                    reasonIfUnsupported, reasonIfUnsupportedMaxLen);
     if (!result.supported)
@@ -650,10 +650,10 @@ LayerTestResult<T, 4> ComparePooling2dTestCommon(armnn::IWorkloadFactory& worklo
     std::unique_ptr<armnn::ITensorHandle> inputHandleRef = refWorkloadFactory.CreateTensorHandle(inputTensorInfo);
 
     // Don't execute if Pooling is not supported, as an exception will be raised.
-    armnn::Compute compute = workloadFactory.GetCompute();
+    armnn::BackendId backend = workloadFactory.GetBackendId();
     const size_t reasonIfUnsupportedMaxLen = 255;
     char reasonIfUnsupported[reasonIfUnsupportedMaxLen+1];
-    comparisonResult.supported = armnn::IsPooling2dSupported(compute, inputTensorInfo, outputTensorInfo,
+    comparisonResult.supported = armnn::IsPooling2dSupported(backend, inputTensorInfo, outputTensorInfo,
                                                              data.m_Parameters,
                                                              reasonIfUnsupported, reasonIfUnsupportedMaxLen);
     if (!comparisonResult.supported)
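
The test changes keep the C-style reason buffer and only swap the first argument of the support query from a Compute value to the factory's BackendId. A self-contained sketch of that query pattern, with FakeIsPooling2dSupported standing in for armnn::IsPooling2dSupported and every name here illustrative:

    #include <cstdio>
    #include <string>

    struct BackendId // stand-in for armnn::BackendId
    {
        std::string m_Id;
        BackendId(const char* id) : m_Id(id) {}
    };

    // Mimics the reason-reporting convention used by the tests above.
    bool FakeIsPooling2dSupported(const BackendId& backend,
                                  char* reasonIfUnsupported,
                                  std::size_t reasonIfUnsupportedMaxLen)
    {
        if (backend.m_Id != "CpuRef")
        {
            std::snprintf(reasonIfUnsupported, reasonIfUnsupportedMaxLen,
                          "backend %s not built in", backend.m_Id.c_str());
            return false;
        }
        return true;
    }

    int main()
    {
        // In the tests this id comes from workloadFactory.GetBackendId().
        BackendId backend{"GpuAcc"};
        const std::size_t reasonIfUnsupportedMaxLen = 255;
        char reasonIfUnsupported[reasonIfUnsupportedMaxLen + 1];
        if (!FakeIsPooling2dSupported(backend, reasonIfUnsupported,
                                      reasonIfUnsupportedMaxLen))
        {
            std::printf("skipped: %s\n", reasonIfUnsupported);
        }
        return 0;
    }
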