Publishing 2019 R3 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / lstm_dynamic / lstm_dynamic_timeloop_kernel_base.cpp
index 611c383..190d13f 100644 (file)
@@ -106,6 +106,27 @@ LSTM_DynamicTimeloopKernelBase::DispatchData LSTM_DynamicTimeloopKernelBase::Set
     return kd;
 }
 
+void kernel_selector::LSTM_DynamicTimeloopKernelBase::SetKernelArguments(const lstm_dynamic_timeloop_params& params, clKernelData& kernel) const {
+    uint32_t input_idx = 0;
+    kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+    kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+    kernel.arguments.push_back({ ArgumentDescriptor::Types::OUTPUT, 0 });
+    kernel.arguments.push_back({ ArgumentDescriptor::Types::RECURRENT, 0 });
+    if (params.has_hidden) {
+        kernel.arguments.push_back({ ArgumentDescriptor::Types::HIDDEN, 0 });
+    }
+    if (params.has_cell) {
+        kernel.arguments.push_back({ ArgumentDescriptor::Types::CELL, 0 });
+    }
+    if (params.has_last_hidden_output) {
+        kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+    }
+    if (params.has_last_cell_output) {
+        kernel.arguments.push_back({ ArgumentDescriptor::Types::INPUT, input_idx++ });
+    }
+}
+
+
 KernelsData LSTM_DynamicTimeloopKernelBase::GetCommonKernelsData(const Params& params,
                                                                  const optional_params& options,
                                                                  float estimated_time) const {
@@ -126,23 +147,7 @@ KernelsData LSTM_DynamicTimeloopKernelBase::GetCommonKernelsData(const Params& p
     kernel.workGroups.global = {run_info.gws0, run_info.gws1, run_info.gws2};
     kernel.workGroups.local  = {run_info.lws0, run_info.lws1, run_info.lws2};
     kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo);
-    uint32_t input_idx = 0;
-    kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
-    kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
-    kernel.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, 0});
-    kernel.arguments.push_back({ArgumentDescriptor::Types::RECURRENT, 0});
-    if (org_params.has_hidden) {
-        kernel.arguments.push_back({ArgumentDescriptor::Types::HIDDEN, 0});
-    }
-    if (org_params.has_cell) {
-        kernel.arguments.push_back({ArgumentDescriptor::Types::CELL, 0});
-    }
-    if (org_params.has_last_hidden_output) {
-        kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
-    }
-    if (org_params.has_last_cell_output) {
-        kernel.arguments.push_back({ArgumentDescriptor::Types::INPUT, input_idx++});
-    }
+    SetKernelArguments(org_params, kernel);
     k_data.estimatedTime = estimated_time;
     return {k_data};
 }