2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>

#include <functional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
/// Supplies a human-readable name for a registered type; the registry embeds it in its
/// "already registered" / "no factory registered" error messages.
///
/// This primary template is the fallback: any RegisteredType without its own
/// specialization of RegisteredTypeName reports the name "UNKNOWN".
template <typename RegisteredType>
struct RegisteredTypeName
{
    /// @return The fallback name, the string literal "UNKNOWN".
    static const char* Name()
    {
        return "UNKNOWN";
    }
};
25 template <typename RegisteredType, typename PointerType, typename ParamType>
29 using FactoryFunction = std::function<PointerType(const ParamType&)>;
31 void Register(const BackendId& id, FactoryFunction factory)
33 if (m_Factories.count(id) > 0)
35 throw InvalidArgumentException(
36 std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
40 m_Factories[id] = factory;
43 FactoryFunction GetFactory(const BackendId& id) const
45 auto it = m_Factories.find(id);
46 if (it == m_Factories.end())
48 throw InvalidArgumentException(
49 std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
56 FactoryFunction GetFactory(const BackendId& id,
57 FactoryFunction defaultFactory) const
59 auto it = m_Factories.find(id);
60 if (it == m_Factories.end())
62 return defaultFactory;
72 return m_Factories.size();
75 BackendIdSet GetBackendIds() const
78 for (const auto& it : m_Factories)
80 result.insert(it.first);
85 std::string GetBackendIdsAsString() const
87 static const std::string delimitator = ", ";
89 std::stringstream output;
90 for (auto& backendId : GetBackendIds())
92 if (output.tellp() != std::streampos(0))
94 output << delimitator;
103 virtual ~RegistryCommon() {}
106 using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;
109 static void Swap(RegistryCommon& instance, FactoryStorage& other)
111 std::swap(instance.m_Factories, other);
115 RegistryCommon(const RegistryCommon&) = delete;
116 RegistryCommon& operator=(const RegistryCommon&) = delete;
118 FactoryStorage m_Factories;
121 template <typename RegistryType>
122 struct StaticRegistryInitializer
124 using FactoryFunction = typename RegistryType::FactoryFunction;
126 StaticRegistryInitializer(RegistryType& instance,
128 FactoryFunction factory)
130 instance.Register(id, factory);