2 // Copyright (c) 2016 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 ///////////////////////////////////////////////////////////////////////////////////////////////////
20 #include "primitive_inst.h"
21 #include "program_impl.h"
23 #include "events_waiter.h"
24 #include "error_handler.h"
25 #include "kernel_selector_helper.h"
27 namespace cldnn { namespace gpu
// Returns true when any node in `users` is implemented as a CPU primitive.
// Used below to decide whether a kernel must signal an output event so the
// host-side consumer can synchronize on the GPU result.
bool is_any_user_cpu(const std::list<const program_node*>& users);
Base class for all GPU implementations of a specified primitive type.
For example, all GPU convolution implementations should derive from typed_primitive_gpu_impl<convolution>.
template <class PType>
struct typed_primitive_gpu_impl : public typed_primitive_impl<PType>
    const typed_program_node<PType>& _outer;              // program node this implementation was built for
    engine_info_internal _engine_info;                    // engine/device info cached from the context
    kernel_selector::kernel_data _kernel_data;            // kernel-selector output describing the kernels to run
    std::vector<gpu::kernel> _kernels;                    // one gpu::kernel per entry in _kernel_data.kernels
    std::vector<memory_impl::cptr> _intermediates_memory; // scratch buffers sized from kd.internalBufferSizes
    // Builds one gpu::kernel for every kernel-selector entry and allocates the
    // intermediate (scratch) buffers the kernels request via internalBufferSizes.
    typed_primitive_gpu_impl(const typed_program_node<PType>& arg, const kernel_selector::kernel_data& kd)
        : typed_primitive_impl<PType>(kd.weightsReorderParams, kd.kernelName)
        , _engine_info(arg.get_program().get_engine().get_context()->get_engine_info())
        _kernels.reserve(kd.kernels.size());
        for (size_t i = 0; i < kd.kernels.size(); ++i)
            gpu::kernel kernel(_outer.get_program().get_engine().get_context(), kd.kernels[i].kernelString);
            _kernels.emplace_back(std::move(kernel));
        // Each internal buffer is allocated as a flat bfyx layout: all `size`
        // bytes are exposed as `size / bpp` elements of the input's data type
        // along the x axis.
        for (auto size : kd.internalBufferSizes)
            auto dtype = arg.input().get_output_layout().data_type;
            const auto bpp = data_type_traits::size_of(dtype);
            layout expected_layout = {
                dtype, format::bfyx, // simple linear format (flatten to x channel)
                { 1,1,1,(tensor::value_type)(size / bpp) }
            auto& eimpl = arg.get_program().get_engine();
            _intermediates_memory.push_back(eimpl.allocate_memory(expected_layout));
    // This hierarchy implements primitives on the GPU — never the host.
    bool is_cpu() const override { return false; }
    // Hook for derived implementations: return true when execution of this
    // instance can be skipped entirely (no kernels need to be enqueued).
    virtual bool optimized_out(typed_primitive_inst<PType>&) const
    // Collects the default kernel arguments for one launch: every input memory
    // of the instance plus its output memory. Derived implementations override
    // this to add per-primitive (and per-split) arguments.
    virtual kernel::kernel_arguments_data get_arguments(typed_primitive_inst<PType>& instance, int32_t /*split*/) const
        kernel::kernel_arguments_data args;
        for (size_t i = 0; i < instance.inputs_memory_count(); i++)
            args.inputs.push_back(&instance.input_memory(i));
        args.output = &instance.output_memory();
    // Number of splits this primitive executes in — each kernel is launched
    // once per split (see execute_impl). Overridable by derived impls.
    virtual int32_t get_split() const
    // Group count for the primitive; overridable by derived implementations.
    // NOTE(review): `groups` is read in execute_impl but not used in the code
    // visible here — presumably consumed by derived classes; confirm.
    virtual uint32_t get_groups() const
    // Collapses a list of events into a single event the caller can wait on:
    // a single event is passed through; with `group` set the context merges
    // them via group_events; otherwise an events_waiter waits on all of them.
    event_impl::ptr aggregate_events(const std::vector<event_impl::ptr>& events, bool group=false) const
        if (events.size() == 1)
        return _outer.get_program().get_engine().get_context()->group_events(events);
        return events_waiter(_outer.get_program().get_engine().get_context()).run(events);
    // Enqueues every kernel of this implementation (once per split), chaining
    // each kernel on the events produced by the previous one, and returns an
    // event that completes when the whole primitive has finished.
    virtual event_impl::ptr execute_impl(const std::vector<event_impl::ptr>& events, typed_primitive_inst<PType>& instance) override
        // Nothing to run — just merge the incoming dependency events.
        if (optimized_out(instance))
        return aggregate_events(events);

        std::vector<event_impl::ptr> tmp_events(events);

        // TODO - split should be handled in kernel selector by providing multiple kernels.
        auto split = get_split();
        auto groups = get_groups();

        // We iterate over kernels in the outer loop so that the per-split
        // launches of one kernel can run in parallel with the OOOQ mechanism.
        for (size_t k = 0; k < _kernels.size(); ++k)
            std::vector<event_impl::ptr> new_events;
            for (decltype(split) i = 0; i < split; i++)
                auto args = get_arguments(instance, i);
                args.scalars = &_kernel_data.kernels[k].scalars;

                // Pass the preallocated scratch buffers to the kernel.
                for (const auto& m : _intermediates_memory)
                    args.intermediates.push_back(m);

                // If any user of this primitive is a CPU primitive, request an
                // output event so the host can synchronize (event won't be nullptr).
                auto users = instance.node.get_users();
                bool next_prim_is_cpu = is_any_user_cpu(users);
                if (next_prim_is_cpu)
                    _kernels[k].set_output_event(true);
                    _kernels[k].set_output_event(instance.node.is_output());

                auto event = _kernels[k].run(_kernel_data.kernels[k], tmp_events, args);
                new_events.push_back(event);

            // Kernel k+1 waits on every per-split launch of kernel k.
            tmp_events = new_events;

        // With more than one split, group the final events instead of waiting.
        bool group_events = split > 1 ? true : false;
        return aggregate_events(tmp_events, group_events);