Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / primitive_gpu_base.h
index 8343147..704b83e 100644 (file)
@@ -18,6 +18,7 @@
 #pragma once
 
 #include "primitive_inst.h"
+#include "program_impl.h"
 #include "kernel.h"
 #include "events_waiter.h"
 #include "error_handler.h"
@@ -30,8 +31,8 @@ namespace cldnn { namespace gpu
 bool is_any_user_cpu(const std::list<const program_node*>& users);
 
 /*
-Base class for all implementation of specified primitive type.
-For example, all convolution implementations should derive from typed_primitive_impl<convolution>.
+Base class for all GPU implementation of specified primitive type.
+For example, all gpu convolution implementations should derive from typed_primitive_gpu_impl<convolution>.
 */
 template <class PType>
 struct typed_primitive_gpu_impl : public typed_primitive_impl<PType>
@@ -67,13 +68,11 @@ struct typed_primitive_gpu_impl : public typed_primitive_impl<PType>
             auto& eimpl = arg.get_program().get_engine();
             _intermediates_memory.push_back(eimpl.allocate_memory(expected_layout));
         }
-    }
-protected:
 
-    virtual bool validate(typed_primitive_inst<PType>&) const
-    {
-        return true;
     }
+    bool is_cpu() const override { return false; }
+
+protected:
 
     virtual bool optimized_out(typed_primitive_inst<PType>&) const
     {
@@ -99,6 +98,11 @@ protected:
         return 1;
     }
 
+    virtual uint32_t get_groups() const
+    {
+        return 1;
+    }
+
     event_impl::ptr aggregate_events(const std::vector<event_impl::ptr>& events, bool group=false) const
     {
         if (events.size() == 1)
@@ -112,9 +116,6 @@ protected:
 
     virtual event_impl::ptr execute_impl(const std::vector<event_impl::ptr>& events, typed_primitive_inst<PType>& instance) override
     {
-        const bool validated = validate(instance);
-        CLDNN_ERROR_NOT_EQUAL(_outer.id(), "validate", validated, "", true, "not a valid instance.");
-
         if (optimized_out(instance))
         {
             return aggregate_events(events);
@@ -124,6 +125,9 @@ protected:
 
         // TODO - split should be handle in kernel selector by providing multiple kernels.
         auto split = get_split();
+        auto groups = get_groups();
+        if (split == 1)
+            split = groups;
 
         // we iterate over split first in order to be able parallelism with OOOQ mechanism.
         for (size_t k = 0; k < _kernels.size(); ++k)