#pragma once
#include "primitive_inst.h"
+#include "program_impl.h"
#include "kernel.h"
#include "events_waiter.h"
#include "error_handler.h"
bool is_any_user_cpu(const std::list<const program_node*>& users);
/*
-Base class for all implementation of specified primitive type.
-For example, all convolution implementations should derive from typed_primitive_impl<convolution>.
+Base class for all GPU implementation of specified primitive type.
+For example, all gpu convolution implementations should derive from typed_primitive_gpu_impl<convolution>.
*/
template <class PType>
struct typed_primitive_gpu_impl : public typed_primitive_impl<PType>
auto& eimpl = arg.get_program().get_engine();
_intermediates_memory.push_back(eimpl.allocate_memory(expected_layout));
}
- }
-protected:
- virtual bool validate(typed_primitive_inst<PType>&) const
- {
- return true;
}
+ bool is_cpu() const override { return false; }
+
+protected:
virtual bool optimized_out(typed_primitive_inst<PType>&) const
{
return 1;
}
+ virtual uint32_t get_groups() const
+ {
+ return 1;
+ }
+
event_impl::ptr aggregate_events(const std::vector<event_impl::ptr>& events, bool group=false) const
{
if (events.size() == 1)
virtual event_impl::ptr execute_impl(const std::vector<event_impl::ptr>& events, typed_primitive_inst<PType>& instance) override
{
- const bool validated = validate(instance);
- CLDNN_ERROR_NOT_EQUAL(_outer.id(), "validate", validated, "", true, "not a valid instance.");
-
if (optimized_out(instance))
{
return aggregate_events(events);
// TODO - split should be handle in kernel selector by providing multiple kernels.
auto split = get_split();
+ auto groups = get_groups();
+ if (split == 1)
+ split = groups;
// we iterate over split first in order to be able parallelism with OOOQ mechanism.
for (size_t k = 0; k < _kernels.size(); ++k)