IVGCVSW-3545 Update the device specs with the dynamic backend ids
authorMatteo Martincigh <matteo.martincigh@arm.com>
Thu, 15 Aug 2019 11:08:06 +0000 (12:08 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Mon, 19 Aug 2019 12:12:28 +0000 (12:12 +0000)
 * Now the utility function RegisterDynamicBackends returns a list of
   the backend ids that have been registered
 * The list of registered ids is added to the list of supported backends
   in the Runtime
 * Added unit tests

Change-Id: I97bbe1f680920358f5baba5a4666e4983b849cac
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
src/armnn/DeviceSpec.hpp
src/armnn/Runtime.cpp
src/backends/backendsCommon/DynamicBackendUtils.cpp
src/backends/backendsCommon/DynamicBackendUtils.hpp
src/backends/backendsCommon/test/DynamicBackendTests.hpp

index 35923e6..3226470 100644 (file)
@@ -24,9 +24,14 @@ public:
         return m_SupportedBackends;
     }
 
+    void AddSupportedBackends(const BackendIdSet& backendIds)
+    {
+        m_SupportedBackends.insert(backendIds.begin(), backendIds.end());
+    }
+
 private:
     DeviceSpec() = delete;
     BackendIdSet m_SupportedBackends;
 };
 
-}
+} // namespace armnn
index 6b91863..9e87484 100644 (file)
@@ -144,8 +144,7 @@ Runtime::Runtime(const CreationOptions& options)
     {
         // Store backend contexts for the supported ones
         const BackendIdSet& supportedBackends = m_DeviceSpec.GetSupportedBackends();
-        auto it = supportedBackends.find(id);
-        if (it != supportedBackends.end())
+        if (supportedBackends.find(id) != supportedBackends.end())
         {
             auto factoryFun = BackendRegistryInstance().GetFactory(id);
             auto backend = factoryFun();
@@ -257,7 +256,10 @@ void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
     m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
 
     // Register the dynamic backends in the backend registry
-    DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
+    BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
+
+    // Add the registered dynamic backend ids to the list of supported backends
+    m_DeviceSpec.AddSupportedBackends(registeredBackendIds);
 }
 
 } // namespace armnn
index fadec0c..fc4336f 100644 (file)
@@ -299,21 +299,25 @@ std::vector<DynamicBackendPtr> DynamicBackendUtils::CreateDynamicBackends(const
     return dynamicBackends;
 }
 
-void DynamicBackendUtils::RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends)
+BackendIdSet DynamicBackendUtils::RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends)
 {
     // Get a reference of the backend registry
     BackendRegistry& backendRegistry = BackendRegistryInstance();
 
-    // Register the dynamic backends in the backend registry
-    RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
+    // Register the dynamic backends in the backend registry, and return a list of registered backend ids
+    return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
 }
 
-void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
-                                                      const std::vector<DynamicBackendPtr>& dynamicBackends)
+BackendIdSet DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
+                                                              const std::vector<DynamicBackendPtr>& dynamicBackends)
 {
+    // Initialize the list of registered backend ids
+    BackendIdSet registeredBackendIds;
+
     // Register the dynamic backends in the backend registry
     for (const DynamicBackendPtr& dynamicBackend : dynamicBackends)
     {
+        // Get the id of the dynamic backend
         BackendId dynamicBackendId;
         try
         {
@@ -362,8 +366,22 @@ void DynamicBackendUtils::RegisterDynamicBackendsImpl(BackendRegistry& backendRe
         }
 
         // Register the dynamic backend
-        backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction);
+        try
+        {
+            backendRegistry.Register(dynamicBackendId, dynamicBackendFactoryFunction);
+        }
+        catch (const InvalidArgumentException& e)
+        {
+            BOOST_LOG_TRIVIAL(warning) << "An error has occurred when registering the dynamic backend \""
+                                       << dynamicBackendId << "\": " << e.what();
+            continue;
+        }
+
+        // Add the id of the dynamic backend just registered to the list of registered backend ids
+        registeredBackendIds.insert(dynamicBackendId);
     }
+
+    return registeredBackendIds;
 }
 
 } // namespace armnn
index 187b0b1..0aa0ac8 100644 (file)
@@ -39,14 +39,14 @@ public:
     static std::vector<std::string> GetSharedObjects(const std::vector<std::string>& backendPaths);
 
     static std::vector<DynamicBackendPtr> CreateDynamicBackends(const std::vector<std::string>& sharedObjects);
-    static void RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends);
+    static BackendIdSet RegisterDynamicBackends(const std::vector<DynamicBackendPtr>& dynamicBackends);
 
 protected:
     /// Protected methods for testing purposes
     static bool IsBackendCompatibleImpl(const BackendVersion& backendApiVersion, const BackendVersion& backendVersion);
     static std::vector<std::string> GetBackendPathsImpl(const std::string& backendPaths);
-    static void RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
-                                            const std::vector<DynamicBackendPtr>& dynamicBackends);
+    static BackendIdSet RegisterDynamicBackendsImpl(BackendRegistry& backendRegistry,
+                                                    const std::vector<DynamicBackendPtr>& dynamicBackends);
 
 private:
     static std::string GetDlError();
index 74ef6f1..e225124 100644 (file)
@@ -79,10 +79,11 @@ public:
         return GetBackendPathsImpl(path);
     }
 
-    static void RegisterDynamicBackendsImplTest(armnn::BackendRegistry& backendRegistry,
-                                                const std::vector<armnn::DynamicBackendPtr>& dynamicBackends)
+    static armnn::BackendIdSet RegisterDynamicBackendsImplTest(
+            armnn::BackendRegistry& backendRegistry,
+            const std::vector<armnn::DynamicBackendPtr>& dynamicBackends)
     {
-        RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
+        return RegisterDynamicBackendsImpl(backendRegistry, dynamicBackends);
     }
 };
 
@@ -896,12 +897,15 @@ void RegisterSingleDynamicBackendTestImpl()
     BackendVersion dynamicBackendVersion = dynamicBackends[0]->GetBackendVersion();
     BOOST_TEST(TestDynamicBackendUtils::IsBackendCompatible(dynamicBackendVersion));
 
-    TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+    BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+                                                                                                 dynamicBackends);
     BOOST_TEST(backendRegistry.Size() == 1);
+    BOOST_TEST(registeredBackendIds.size() == 1);
 
     BackendIdSet backendIds = backendRegistry.GetBackendIds();
     BOOST_TEST(backendIds.size() == 1);
     BOOST_TEST((backendIds.find(dynamicBackendId) != backendIds.end()));
+    BOOST_TEST((registeredBackendIds.find(dynamicBackendId) != registeredBackendIds.end()));
 
     auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(dynamicBackendId);
     BOOST_TEST((dynamicBackendFactoryFunction != nullptr));
@@ -960,14 +964,19 @@ void RegisterMultipleDynamicBackendsTestImpl()
     BackendRegistry backendRegistry;
     BOOST_TEST(backendRegistry.Size() == 0);
 
-    TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+    BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+                                                                                                 dynamicBackends);
     BOOST_TEST(backendRegistry.Size() == 3);
+    BOOST_TEST(registeredBackendIds.size() == 3);
 
     BackendIdSet backendIds = backendRegistry.GetBackendIds();
     BOOST_TEST(backendIds.size() == 3);
     BOOST_TEST((backendIds.find(dynamicBackendId1) != backendIds.end()));
     BOOST_TEST((backendIds.find(dynamicBackendId2) != backendIds.end()));
     BOOST_TEST((backendIds.find(dynamicBackendId3) != backendIds.end()));
+    BOOST_TEST((registeredBackendIds.find(dynamicBackendId1) != registeredBackendIds.end()));
+    BOOST_TEST((registeredBackendIds.find(dynamicBackendId2) != registeredBackendIds.end()));
+    BOOST_TEST((registeredBackendIds.find(dynamicBackendId3) != registeredBackendIds.end()));
 
     for (size_t i = 0; i < dynamicBackends.size(); i++)
     {
@@ -1036,8 +1045,10 @@ void RegisterMultipleInvalidDynamicBackendsTestImpl()
     BOOST_TEST(backendRegistry.Size() == 0);
 
     // Check that no dynamic backend got registered
-    TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+    BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+                                                                                                 dynamicBackends);
     BOOST_TEST(backendRegistry.Size() == 0);
+    BOOST_TEST(registeredBackendIds.empty());
 }
 
 void RegisterMixedDynamicBackendsTestImpl()
@@ -1165,14 +1176,17 @@ void RegisterMixedDynamicBackendsTestImpl()
         "TestValid5"
     };
 
-    TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry, dynamicBackends);
+    BackendIdSet registeredBackendIds = TestDynamicBackendUtils::RegisterDynamicBackendsImplTest(backendRegistry,
+                                                                                                 dynamicBackends);
     BOOST_TEST(backendRegistry.Size() == expectedRegisteredbackendIds.size());
+    BOOST_TEST(registeredBackendIds.size() == expectedRegisteredbackendIds.size());
 
     BackendIdSet backendIds = backendRegistry.GetBackendIds();
     BOOST_TEST(backendIds.size() == expectedRegisteredbackendIds.size());
     for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds)
     {
         BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end()));
+        BOOST_TEST((registeredBackendIds.find(expectedRegisteredbackendId) != registeredBackendIds.end()));
 
         auto dynamicBackendFactoryFunction = backendRegistry.GetFactory(expectedRegisteredbackendId);
         BOOST_TEST((dynamicBackendFactoryFunction != nullptr));
@@ -1190,10 +1204,16 @@ void RuntimeEmptyTestImpl()
     // Swapping the backend registry storage for testing
     TestBackendRegistry testBackendRegistry;
 
+    const BackendRegistry& backendRegistry = BackendRegistryInstance();
+    BOOST_TEST(backendRegistry.Size() == 0);
+
     IRuntime::CreationOptions creationOptions;
     IRuntimePtr runtime = IRuntime::Create(creationOptions);
 
-    const BackendRegistry& backendRegistry = BackendRegistryInstance();
+    const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+    BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+    BOOST_TEST(supportedBackendIds.empty());
+
     BOOST_TEST(backendRegistry.Size() == 0);
 }
 
@@ -1228,6 +1248,14 @@ void RuntimeDynamicBackendsTestImpl()
     {
         BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end()));
     }
+
+    const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+    BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+    BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size());
+    for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds)
+    {
+        BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end()));
+    }
 }
 
 void RuntimeDuplicateDynamicBackendsTestImpl()
@@ -1261,6 +1289,14 @@ void RuntimeDuplicateDynamicBackendsTestImpl()
     {
         BOOST_TEST((backendIds.find(expectedRegisteredbackendId) != backendIds.end()));
     }
+
+    const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+    BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+    BOOST_TEST(supportedBackendIds.size() == expectedRegisteredbackendIds.size());
+    for (const BackendId& expectedRegisteredbackendId : expectedRegisteredbackendIds)
+    {
+        BOOST_TEST((supportedBackendIds.find(expectedRegisteredbackendId) != supportedBackendIds.end()));
+    }
 }
 
 void RuntimeInvalidDynamicBackendsTestImpl()
@@ -1282,6 +1318,10 @@ void RuntimeInvalidDynamicBackendsTestImpl()
 
     const BackendRegistry& backendRegistry = BackendRegistryInstance();
     BOOST_TEST(backendRegistry.Size() == 0);
+
+    const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+    BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+    BOOST_TEST(supportedBackendIds.empty());
 }
 
 void RuntimeInvalidOverridePathTestImpl()
@@ -1298,6 +1338,10 @@ void RuntimeInvalidOverridePathTestImpl()
 
     const BackendRegistry& backendRegistry = BackendRegistryInstance();
     BOOST_TEST(backendRegistry.Size() == 0);
+
+    const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+    BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+    BOOST_TEST(supportedBackendIds.empty());
 }
 
 void CreateReferenceDynamicBackendTestImpl()
@@ -1330,6 +1374,11 @@ void CreateReferenceDynamicBackendTestImpl()
     BackendIdSet backendIds = backendRegistry.GetBackendIds();
     BOOST_TEST((backendIds.find("CpuRef") != backendIds.end()));
 
+    const DeviceSpec& deviceSpec = *boost::polymorphic_downcast<const DeviceSpec*>(&runtime->GetDeviceSpec());
+    BackendIdSet supportedBackendIds = deviceSpec.GetSupportedBackends();
+    BOOST_TEST(supportedBackendIds.size() == 1);
+    BOOST_TEST((supportedBackendIds.find("CpuRef") != supportedBackendIds.end()));
+
     // Get the factory function
     auto referenceDynamicBackendFactoryFunction = backendRegistry.GetFactory("CpuRef");
     BOOST_TEST((referenceDynamicBackendFactoryFunction != nullptr));