IVGCVSW-2125 : backends now can return optimizations
authorDavid Beck <david.beck@arm.com>
Fri, 9 Nov 2018 14:46:40 +0000 (14:46 +0000)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Fri, 9 Nov 2018 16:44:01 +0000 (16:44 +0000)
Change-Id: Ieec34224b433e1d2f3bbe66632cd6016cac5498c

src/armnn/Network.cpp
src/backends/backendsCommon/IBackendInternal.hpp
src/backends/cl/ClBackend.cpp
src/backends/cl/ClBackend.hpp
src/backends/neon/NeonBackend.cpp
src/backends/neon/NeonBackend.hpp
src/backends/reference/RefBackend.cpp
src/backends/reference/RefBackend.hpp

index 43782e0..7b430c3 100644 (file)
@@ -11,6 +11,8 @@
 
 #include <backendsCommon/CpuTensorHandle.hpp>
 #include <backendsCommon/WorkloadFactory.hpp>
+#include <backendsCommon/BackendRegistry.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
 
 #include <armnn/Exceptions.hpp>
 #include <armnn/Utils.hpp>
@@ -169,6 +171,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
         return IOptimizedNetworkPtr(nullptr, &IOptimizedNetwork::Destroy);
     };
 
+    // The backends that we choose to run layers on
+    std::unordered_set<BackendId> chosenBackends;
+
     // Assign a compute device for all nodes
     bool bErrorFound = false;
     for (auto&& layer : optNetObjPtr->GetGraph())
@@ -275,6 +280,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
             else
             {
                 found = true;
+                chosenBackends.insert(backend);
                 break;
             }
         }
@@ -291,6 +297,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
                                 layerType == armnn::LayerType::Permute))
             {
                 layer->SetBackendId(armnn::Compute::CpuRef);
+                chosenBackends.insert(armnn::Compute::CpuRef);
             }
             else
             {
@@ -312,6 +319,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
     Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsFloatToHalf()));
     Optimizer::Pass(optNetObjPtr->GetGraph(), MakeOptimizations(ConvertConstantsHalfToFloat()));
 
+    // Run backend specific optimizations
+    for (auto&& chosenBackend : chosenBackends)
+    {
+        auto factoryFun = BackendRegistryInstance().GetFactory(chosenBackend);
+        auto backendPtr = factoryFun();
+        BOOST_ASSERT(backendPtr.get() != nullptr);
+
+        auto backendSpecificOptimizations = backendPtr->GetOptimizations();
+        if (!backendSpecificOptimizations.empty())
+        {
+            Optimizer::Pass(optNetObjPtr->GetGraph(), backendSpecificOptimizations);
+        }
+    }
+
     return optNet;
 }
 
index fede366..9c54b82 100644 (file)
@@ -6,11 +6,13 @@
 
 #include <armnn/Types.hpp>
 #include <armnn/IRuntime.hpp>
+#include <vector>
 
 namespace armnn
 {
 class IWorkloadFactory;
 class IBackendContext;
+class Optimization;
 
 class IBackendInternal : public IBackend
 {
@@ -26,9 +28,12 @@ public:
 
     using IWorkloadFactoryPtr = std::unique_ptr<IWorkloadFactory>;
     using IBackendContextPtr = std::unique_ptr<IBackendContext>;
+    using OptimizationPtr = std::unique_ptr<Optimization>;
+    using Optimizations = std::vector<OptimizationPtr>;
 
     virtual IWorkloadFactoryPtr CreateWorkloadFactory() const = 0;
     virtual IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const = 0;
+    virtual Optimizations GetOptimizations() const = 0;
 };
 
 using IBackendInternalUniquePtr = std::unique_ptr<IBackendInternal>;
index c07fa66..8209a10 100644 (file)
@@ -8,7 +8,9 @@
 #include "ClWorkloadFactory.hpp"
 #include "ClBackendContext.hpp"
 
+#include <backendsCommon/IBackendContext.hpp>
 #include <backendsCommon/BackendRegistry.hpp>
+#include <Optimizer.hpp>
 
 namespace armnn
 {
@@ -45,5 +47,9 @@ ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const
     return IBackendContextPtr{new ClBackendContext{options}};
 }
 
+IBackendInternal::Optimizations ClBackend::GetOptimizations() const
+{
+    return Optimizations{};
+}
 
 } // namespace armnn
index f8a6253..ad84e8a 100644 (file)
@@ -4,7 +4,6 @@
 //
 #pragma once
 
-#include <backendsCommon/IBackendContext.hpp>
 #include <backendsCommon/IBackendInternal.hpp>
 
 namespace armnn
@@ -21,6 +20,7 @@ public:
 
     IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override;
     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+    IBackendInternal::Optimizations GetOptimizations() const override;
 };
 
 } // namespace armnn
\ No newline at end of file
index 7058d24..9e079f3 100644 (file)
@@ -7,7 +7,9 @@
 #include "NeonBackendId.hpp"
 #include "NeonWorkloadFactory.hpp"
 
+#include <backendsCommon/IBackendContext.hpp>
 #include <backendsCommon/BackendRegistry.hpp>
+#include <Optimizer.hpp>
 
 #include <boost/cast.hpp>
 
@@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr NeonBackend::CreateWorkloadFactory() const
     return std::make_unique<NeonWorkloadFactory>();
 }
 
+IBackendInternal::IBackendContextPtr NeonBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
+{
+    return IBackendContextPtr{};
+}
+
+IBackendInternal::Optimizations NeonBackend::GetOptimizations() const
+{
+    return Optimizations{};
+}
+
 } // namespace armnn
\ No newline at end of file
index 9ee8b23..e0017d9 100644 (file)
@@ -4,7 +4,6 @@
 //
 #pragma once
 
-#include <backendsCommon/IBackendContext.hpp>
 #include <backendsCommon/IBackendInternal.hpp>
 
 namespace armnn
@@ -20,11 +19,8 @@ public:
     const BackendId& GetId() const override { return GetIdStatic(); }
 
     IWorkloadFactoryPtr CreateWorkloadFactory() const override;
-
-    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
-    {
-        return IBackendContextPtr{};
-    }
+    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+    IBackendInternal::Optimizations GetOptimizations() const override;
 };
 
 } // namespace armnn
\ No newline at end of file
index b6fb0ff..2f5ec80 100644 (file)
@@ -7,7 +7,9 @@
 #include "RefBackendId.hpp"
 #include "RefWorkloadFactory.hpp"
 
+#include <backendsCommon/IBackendContext.hpp>
 #include <backendsCommon/BackendRegistry.hpp>
+#include <Optimizer.hpp>
 
 #include <boost/cast.hpp>
 
@@ -40,4 +42,14 @@ IBackendInternal::IWorkloadFactoryPtr RefBackend::CreateWorkloadFactory() const
     return std::make_unique<RefWorkloadFactory>();
 }
 
+IBackendInternal::IBackendContextPtr RefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
+{
+    return IBackendContextPtr{};
+}
+
+IBackendInternal::Optimizations RefBackend::GetOptimizations() const
+{
+    return Optimizations{};
+}
+
 } // namespace armnn
\ No newline at end of file
index 025a482..be71f35 100644 (file)
@@ -4,7 +4,6 @@
 //
 #pragma once
 
-#include <backendsCommon/IBackendContext.hpp>
 #include <backendsCommon/IBackendInternal.hpp>
 
 namespace armnn
@@ -20,11 +19,8 @@ public:
     const BackendId& GetId() const override { return GetIdStatic(); }
 
     IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory() const override;
-
-    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
-    {
-        return IBackendContextPtr{};
-    }
+    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+    IBackendInternal::Optimizations GetOptimizations() const override;
 };
 
 } // namespace armnn
\ No newline at end of file