// Copyright (c) 2016 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
22 template<typename T, typename U>
23 class singleton_map : public std::map<T, U> {
24 singleton_map() : std::map<T, U>() {};
25 singleton_map(singleton_map const&) = delete;
26 void operator=(singleton_map const&) = delete;
29 static singleton_map &instance() {
30 static singleton_map instance_;
// Forward declarations of the primitive kinds that have their own
// implementation_key specialization; only the names are needed here.
struct permute;
struct reorder;
struct generic_layer;
struct custom_gpu_primitive;
struct reshape;
struct data;
struct mutable_data;
struct input_layout;
struct prior_box;

// Object returned (as a raw pointer) by the factories stored in implementation_map.
struct primitive_impl;

// Program-graph node for a primitive of type PType; only referenced by this header.
template <class PType>
struct typed_program_node;
52 template<typename primitive_kind>
53 struct implementation_key
55 typedef std::tuple<engine_types, data_types, format::type> type;
56 type operator()(engine_types engine_type, const typed_program_node<primitive_kind>& primitive)
58 return std::make_tuple(engine_type, primitive.get_dependency(0).get_output_layout().data_type, primitive.get_dependency(0).get_output_layout().format);
60 type operator()(engine_types engine_type, const layout& proposed_layout)
62 return std::make_tuple(engine_type, proposed_layout.data_type, proposed_layout.format);
67 struct implementation_key<permute>
69 typedef cldnn::engine_types type;
70 type operator()(engine_types engine_type, const typed_program_node<permute>&)
74 type operator()(engine_types engine_type, const layout&)
81 struct implementation_key<reorder>
83 typedef cldnn::engine_types type;
84 type operator()(engine_types engine_type, const typed_program_node<reorder>&)
88 type operator()(engine_types engine_type, const layout&)
96 struct implementation_key<generic_layer>
98 typedef cldnn::engine_types type;
99 type operator()(engine_types engine_type, const typed_program_node<generic_layer>&)
103 type operator()(engine_types engine_type, const layout&)
111 struct implementation_key<custom_gpu_primitive>
113 typedef cldnn::engine_types type;
114 type operator()(engine_types engine_type, const typed_program_node<custom_gpu_primitive>&)
118 type operator()(engine_types engine_type, const layout&)
126 struct implementation_key<reshape>
128 typedef cldnn::engine_types type;
129 type operator()(engine_types engine_type, const typed_program_node<reshape>&)
133 type operator()(engine_types engine_type, const layout&)
141 struct implementation_key<data>
143 typedef cldnn::engine_types type;
144 type operator()(engine_types engine_type, const typed_program_node<data>&)
148 type operator()(engine_types engine_type, const layout&)
156 struct implementation_key<mutable_data>
158 typedef cldnn::engine_types type;
159 type operator()(engine_types engine_type, const typed_program_node<mutable_data>&)
163 type operator()(engine_types engine_type, const layout&)
170 struct implementation_key<input_layout>
172 typedef cldnn::engine_types type;
173 type operator()(engine_types engine_type, const typed_program_node<input_layout>&)
177 type operator()(engine_types engine_type, const layout&)
185 struct implementation_key<prior_box>
187 typedef cldnn::engine_types type;
188 type operator()(engine_types engine_type, const typed_program_node<prior_box>&)
192 type operator()(engine_types engine_type, const layout&)
198 template<typename primitive_kind>
199 class implementation_map {
201 using key_builder = implementation_key<primitive_kind>;
202 using key_type = typename key_builder::type;
203 using factory_type = std::function<primitive_impl*(const typed_program_node<primitive_kind>&)>;
204 using map_type = singleton_map<key_type, factory_type>;
206 static factory_type get(engine_types engine_type, const typed_program_node<primitive_kind>& primitive) {
207 // lookup in database; throw if not found
208 auto key = key_builder()(engine_type, primitive);
209 auto it = map_type::instance().find(key);
210 if (it == std::end(map_type::instance()))
211 throw std::runtime_error(
212 std::string("implementation_map for ") + typeid(primitive_kind).name()
213 + " could not find any implementation to match key");
214 // create implementation & attach it to result
218 //check if for a given engine and type there exist an implementation
219 static bool check(engine_types engine_type, const typed_program_node<primitive_kind>& primitive)
221 auto key = key_builder()(engine_type, primitive);
222 auto it = map_type::instance().find(key);
223 if (it == std::end(map_type::instance()))
229 //check if there exists a kernel implementation of a primitive with output set it primitive's output layout
230 static bool check_io_eq(engine_types engine_type, const typed_program_node<primitive_kind>& primitive)
232 auto key = key_builder()(engine_type, primitive.get_output_layout());
233 auto it = map_type::instance().find(key);
234 if (it == std::end(map_type::instance()))
240 static void add(typename map_type::key_type key, factory_type factory) {
241 map_type::instance().insert({ key, factory });
244 static void add(std::initializer_list<typename map_type::value_type> il) {
245 map_type::instance().insert(il);