Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / include / implementation_map.h
index 5fc2710..7472038 100644 (file)
@@ -57,6 +57,10 @@ struct implementation_key
     {
         return std::make_tuple(engine_type, primitive.get_dependency(0).get_output_layout().data_type, primitive.get_dependency(0).get_output_layout().format);
     }
+    type operator()(engine_types engine_type, const layout& proposed_layout)
+    {
+        return std::make_tuple(engine_type, proposed_layout.data_type, proposed_layout.format);
+    }
 };
 
 template<>
@@ -67,6 +71,10 @@ struct implementation_key<permute>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
 };
 
 template<>
@@ -77,6 +85,11 @@ struct implementation_key<reorder>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
+
 };
 
 template<>
@@ -87,6 +100,11 @@ struct implementation_key<generic_layer>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
+
 };
 
 template<>
@@ -97,6 +115,11 @@ struct implementation_key<custom_gpu_primitive>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
+
 };
 
 template<>
@@ -107,6 +130,11 @@ struct implementation_key<reshape>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
+
 };
 
 template<>
@@ -117,6 +145,11 @@ struct implementation_key<data>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
+
 };
 
 template<>
@@ -127,6 +160,10 @@ struct implementation_key<mutable_data>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
 };
 
 template<>
@@ -137,6 +174,11 @@ struct implementation_key<input_layout>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
+
 };
 
 template<>
@@ -147,6 +189,10 @@ struct implementation_key<prior_box>
     {
         return engine_type;
     }
+    type operator()(engine_types engine_type, const layout&)
+    {
+        return engine_type;
+    }
 };
 
 template<typename primitive_kind>
@@ -162,12 +208,35 @@ public:
         auto key = key_builder()(engine_type, primitive);
         auto it = map_type::instance().find(key);
         if (it == std::end(map_type::instance())) 
-            throw std::runtime_error(std::string("implementation_map for ")+typeid(primitive_kind).name() +" could not find any implementation to match key");
-
+            throw std::runtime_error(
+                std::string("implementation_map for ") + typeid(primitive_kind).name()
+                    + " could not find any implementation to match key");
         // create implementation & attach it to result 
         return it->second;
     }
 
+    //check if for a given engine and type there exist an implementation
+    static bool check(engine_types engine_type, const typed_program_node<primitive_kind>& primitive)
+    {
+        auto key = key_builder()(engine_type, primitive);
+        auto it = map_type::instance().find(key);
+        if (it == std::end(map_type::instance()))
+            return false;
+        else
+            return true;
+    }
+
+    //check if there exists a kernel implementation of a primitive with output set it primitive's output layout
+    static bool check_io_eq(engine_types engine_type, const typed_program_node<primitive_kind>& primitive)
+    {
+        auto key = key_builder()(engine_type, primitive.get_output_layout());
+        auto it = map_type::instance().find(key);
+        if (it == std::end(map_type::instance()))
+            return false;
+        else
+            return true;
+    }
+
     static void add(typename map_type::key_type key, factory_type factory) {
         map_type::instance().insert({ key, factory });
     }