2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
7 #include <armnn/BackendId.hpp>
8 #include <armnn/Exceptions.hpp>
13 #include <unordered_map>
// Trait giving a printable name for a registered type. This primary template
// returns the literal "UNKNOWN"; the name is used below when composing
// registry error messages.
// NOTE(review): presumably specialised elsewhere per registered type - confirm.
18 template <typename RegisteredType>
19 struct RegisteredTypeName
// Returns the name used for RegisteredType in diagnostics.
21 static const char * Name() { return "UNKNOWN"; }
// Template header for the registry type defined below (its class header is not
// visible in this excerpt). RegisteredType is only used here to select the
// RegisteredTypeName used in error messages; PointerType and ParamType define
// the factory signature.
24 template <typename RegisteredType, typename PointerType, typename ParamType>
// A factory is any callable taking a const ParamType& and returning PointerType.
28 using FactoryFunction = std::function<PointerType(const ParamType&)>;
// Stores `factory` in m_Factories under the backend `id`.
// Throws InvalidArgumentException when an entry for `id` already exists; the
// message names the backend and the RegisteredTypeName of RegisteredType.
30 void Register(const BackendId& id, FactoryFunction factory)
// Reject duplicates explicitly: the operator[] assignment below would
// otherwise overwrite the existing entry.
32 if (m_Factories.count(id) > 0)
34 throw InvalidArgumentException(
35 std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
39 m_Factories[id] = factory;
// Looks up the factory stored for `id`.
// Throws InvalidArgumentException when m_Factories has no entry for `id`.
// NOTE(review): the statement returning the found factory is not visible in
// this excerpt - presumably it returns the mapped value; confirm.
42 FactoryFunction GetFactory(const BackendId& id) const
44 auto it = m_Factories.find(id);
// find() yields end() when the key is absent.
45 if (it == m_Factories.end())
47 throw InvalidArgumentException(
48 std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
// Non-throwing lookup: returns `defaultFactory` when m_Factories has no entry
// for `id`.
// NOTE(review): the branch for a found entry is not visible in this excerpt -
// presumably it returns the stored factory; confirm.
55 FactoryFunction GetFactory(const BackendId& id,
56 FactoryFunction defaultFactory) const
58 auto it = m_Factories.find(id);
// Missing key: fall back to the caller-supplied factory instead of throwing.
59 if (it == m_Factories.end())
61 return defaultFactory;
// Returns the number of entries currently held in m_Factories.
// NOTE(review): the enclosing function's signature is not visible in this excerpt.
71 return m_Factories.size();
// Collects the key (backend id) of every entry in m_Factories into `result`.
// NOTE(review): the declaration and return of `result` are not visible in this
// excerpt - presumably a BackendIdSet that is returned by value; confirm.
74 BackendIdSet GetBackendIds() const
// Each element is a key/value pair; only the key (it.first) is kept.
77 for (const auto& it : m_Factories)
79 result.insert(it.first);
// Formats the ids from GetBackendIds() into one string, separated by ", ".
// NOTE(review): the lines that write each id to `output` and return the
// resulting string are not visible in this excerpt - confirm.
84 std::string GetBackendIdsAsString() const
86 static const std::string delimitator = ", ";
88 std::stringstream output;
89 for (auto& backendId : GetBackendIds())
// A write position of 0 means nothing has been written yet, so the delimiter
// is skipped before the first element.
91 if (output.tellp() != std::streampos(0))
93 output << delimitator;
// Virtual destructor with an empty body.
102 virtual ~RegistryCommon() {}
// Underlying container type: maps a BackendId to its FactoryFunction.
105 using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;
// Exchanges the factory map of `instance` with `other`: afterwards `instance`
// holds what `other` held and vice versa.
// NOTE(review): looks like a hook for tests or derived classes to replace the
// registered set - confirm the intended callers.
108 static void Swap(RegistryCommon& instance, FactoryStorage& other)
110 std::swap(instance.m_Factories, other);
// Copy construction and copy assignment are deleted, so a registry cannot be copied.
114 RegistryCommon(const RegistryCommon&) = delete;
115 RegistryCommon& operator=(const RegistryCommon&) = delete;
// Registered factories, keyed by backend id.
117 FactoryStorage m_Factories;
// Helper whose constructor registers a factory with a registry by calling
// instance.Register(id, factory).
// NOTE(review): the name suggests it is meant to be instantiated as a static
// object so registration happens during static initialisation - confirm.
120 template <typename RegistryType>
121 struct StaticRegistryInitializer
// Reuse the registry's own factory signature.
123 using FactoryFunction = typename RegistryType::FactoryFunction;
// NOTE(review): the declaration of the `id` parameter is not visible in this
// excerpt; it is passed straight through to Register below.
125 StaticRegistryInitializer(RegistryType& instance,
127 FactoryFunction factory)
// Any exception thrown by Register (for example on a duplicate id) propagates
// out of this constructor.
129 instance.Register(id, factory);