IVGCVSW-2093 Add SpaceToBatchNd layer and corresponding no-op factory implementations
[platform/upstream/armnn.git] / src / backends / RegistryCommon.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
14
15 namespace armnn
16 {
17
18 template <typename RegisteredType>
19 struct RegisteredTypeName
20 {
21     static const char * Name() { return "UNKNOWN"; }
22 };
23
24 template <typename RegisteredType, typename PointerType, typename ParamType>
25 class RegistryCommon
26 {
27 public:
28     using FactoryFunction = std::function<PointerType(const ParamType&)>;
29
30     void Register(const BackendId& id, FactoryFunction factory)
31     {
32         if (m_Factories.count(id) > 0)
33         {
34             throw InvalidArgumentException(
35                 std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
36                 CHECK_LOCATION());
37         }
38
39         m_Factories[id] = factory;
40     }
41
42     FactoryFunction GetFactory(const BackendId& id) const
43     {
44         auto it = m_Factories.find(id);
45         if (it == m_Factories.end())
46         {
47             throw InvalidArgumentException(
48                 std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
49                 CHECK_LOCATION());
50         }
51
52         return it->second;
53     }
54
55     FactoryFunction GetFactory(const BackendId& id,
56                                FactoryFunction defaultFactory) const
57     {
58         auto it = m_Factories.find(id);
59         if (it == m_Factories.end())
60         {
61             return defaultFactory;
62         }
63         else
64         {
65             return it->second;
66         }
67     }
68
69     size_t Size() const
70     {
71         return m_Factories.size();
72     }
73
74     BackendIdSet GetBackendIds() const
75     {
76         BackendIdSet result;
77         for (const auto& it : m_Factories)
78         {
79             result.insert(it.first);
80         }
81         return result;
82     }
83
84     std::string GetBackendIdsAsString() const
85     {
86         static const std::string delimitator = ", ";
87
88         std::stringstream output;
89         for (auto& backendId : GetBackendIds())
90         {
91             if (output.tellp() != std::streampos(0))
92             {
93                 output << delimitator;
94             }
95             output << backendId;
96         }
97
98         return output.str();
99     }
100
101     RegistryCommon() {}
102     virtual ~RegistryCommon() {}
103
104 protected:
105     using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;
106
107     // For testing only
108     static void Swap(RegistryCommon& instance, FactoryStorage& other)
109     {
110         std::swap(instance.m_Factories, other);
111     }
112
113 private:
114     RegistryCommon(const RegistryCommon&) = delete;
115     RegistryCommon& operator=(const RegistryCommon&) = delete;
116
117     FactoryStorage m_Factories;
118 };
119
120 template <typename RegistryType>
121 struct StaticRegistryInitializer
122 {
123     using FactoryFunction = typename RegistryType::FactoryFunction;
124
125     StaticRegistryInitializer(RegistryType& instance,
126                               const BackendId& id,
127                               FactoryFunction factory)
128     {
129         instance.Register(id, factory);
130     }
131 };
132
133 } // namespace armnn