IVGCVSW-2093 Add SpaceToBatchNd layer and corresponding no-op factory implementations
[platform/upstream/armnn.git] / src / backends / backendsCommon / RegistryCommon.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>

#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
15
16 namespace armnn
17 {
18
/// Human-readable name of a registered type, used when building registry
/// error messages.
///
/// This primary template is the fallback and always reports "UNKNOWN";
/// specialize it for a concrete RegisteredType to get a meaningful name.
template <typename RegisteredType>
struct RegisteredTypeName
{
    static const char * Name()
    {
        return "UNKNOWN";
    }
};
24
25 template <typename RegisteredType, typename PointerType, typename ParamType>
26 class RegistryCommon
27 {
28 public:
29     using FactoryFunction = std::function<PointerType(const ParamType&)>;
30
31     void Register(const BackendId& id, FactoryFunction factory)
32     {
33         if (m_Factories.count(id) > 0)
34         {
35             throw InvalidArgumentException(
36                 std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
37                 CHECK_LOCATION());
38         }
39
40         m_Factories[id] = factory;
41     }
42
43     FactoryFunction GetFactory(const BackendId& id) const
44     {
45         auto it = m_Factories.find(id);
46         if (it == m_Factories.end())
47         {
48             throw InvalidArgumentException(
49                 std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
50                 CHECK_LOCATION());
51         }
52
53         return it->second;
54     }
55
56     FactoryFunction GetFactory(const BackendId& id,
57                                FactoryFunction defaultFactory) const
58     {
59         auto it = m_Factories.find(id);
60         if (it == m_Factories.end())
61         {
62             return defaultFactory;
63         }
64         else
65         {
66             return it->second;
67         }
68     }
69
70     size_t Size() const
71     {
72         return m_Factories.size();
73     }
74
75     BackendIdSet GetBackendIds() const
76     {
77         BackendIdSet result;
78         for (const auto& it : m_Factories)
79         {
80             result.insert(it.first);
81         }
82         return result;
83     }
84
85     std::string GetBackendIdsAsString() const
86     {
87         static const std::string delimitator = ", ";
88
89         std::stringstream output;
90         for (auto& backendId : GetBackendIds())
91         {
92             if (output.tellp() != std::streampos(0))
93             {
94                 output << delimitator;
95             }
96             output << backendId;
97         }
98
99         return output.str();
100     }
101
102     RegistryCommon() {}
103     virtual ~RegistryCommon() {}
104
105 protected:
106     using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;
107
108     // For testing only
109     static void Swap(RegistryCommon& instance, FactoryStorage& other)
110     {
111         std::swap(instance.m_Factories, other);
112     }
113
114 private:
115     RegistryCommon(const RegistryCommon&) = delete;
116     RegistryCommon& operator=(const RegistryCommon&) = delete;
117
118     FactoryStorage m_Factories;
119 };
120
121 template <typename RegistryType>
122 struct StaticRegistryInitializer
123 {
124     using FactoryFunction = typename RegistryType::FactoryFunction;
125
126     StaticRegistryInitializer(RegistryType& instance,
127                               const BackendId& id,
128                               FactoryFunction factory)
129     {
130         instance.Register(id, factory);
131     }
132 };
133
134 } // namespace armnn