return kd;
}
+void kernel_selector::LSTM_DynamicTimeloopKernelBase::SetKernelArguments(const lstm_dynamic_timeloop_params& params, clKernelData& kernel) const {
+ uint32_t input_idx = 0;
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::RECURRENT, 0 });
+ if (params.has_hidden) {
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::HIDDEN, 0 });
+ }
+ if (params.has_cell) {
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::CELL, 0 });
+ }
+ if (params.has_last_hidden_output) {
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+ }
+ if (params.has_last_cell_output) {
+ kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+ }
+}
+
+
KernelsData LSTM_DynamicTimeloopKernelBase::GetCommonKernelsData(const Params& params,
const optional_params& options,
float estimated_time) const {
kernel.workGroups.global = {run_info.gws0, run_info.gws1, run_info.gws2};
kernel.workGroups.local = {run_info.lws0, run_info.lws1, run_info.lws2};
kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo);
- uint32_t input_idx = 0;
- kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
- kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
- kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
- kernel.arguments.push_back({ArgumentDescriptor::Types::RECURRENT, 0});
- if (org_params.has_hidden) {
- kernel.arguments.push_back({ArgumentDescriptor::Types::HIDDEN, 0});
- }
- if (org_params.has_cell) {
- kernel.arguments.push_back({ArgumentDescriptor::Types::CELL, 0});
- }
- if (org_params.has_last_hidden_output) {
- kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
- }
- if (org_params.has_last_cell_output) {
- kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
- }
+ SetKernelArguments(org_params, kernel);
k_data.estimatedTime = estimated_time;
return {k_data};
}