Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / src / gpu / index_select_gpu.cpp
index 0dab915..41f826a 100644 (file)
@@ -26,17 +26,22 @@ namespace cldnn { namespace gpu {
 
 namespace
 {
-    inline kernel_selector::IndexSelectAxis convect_to_index_select_axis(index_select_axis_name axis)
+    inline std::vector<kernel_selector::IndexSelectAxis> convert_to_index_select_axis(std::vector<index_select_axis_name> axes)
     {
-        switch (axis)
+        std::vector<kernel_selector::IndexSelectAxis> axes_names = {};
+        for (size_t i = 0; i < axes.size(); i++)
         {
-        case index_select_axis_name::along_b:  return kernel_selector::IndexSelectAxis::BATCH;
-        case index_select_axis_name::along_f:  return kernel_selector::IndexSelectAxis::FEATURE;
-        case index_select_axis_name::along_x:  return kernel_selector::IndexSelectAxis::X;
-        case index_select_axis_name::along_y: return kernel_selector::IndexSelectAxis::Y;
-        default:
-            return kernel_selector::IndexSelectAxis::BATCH;
+            switch (axes[i])
+            {
+            case index_select_axis_name::along_b:  axes_names.push_back(kernel_selector::IndexSelectAxis::BATCH); break;
+            case index_select_axis_name::along_f:  axes_names.push_back(kernel_selector::IndexSelectAxis::FEATURE); break;
+            case index_select_axis_name::along_x:  axes_names.push_back(kernel_selector::IndexSelectAxis::X); break;
+            case index_select_axis_name::along_y:  axes_names.push_back(kernel_selector::IndexSelectAxis::Y); break;
+            default:
+                axes_names.push_back(kernel_selector::IndexSelectAxis::BATCH); break;
+            }
         }
+        return axes_names;
     }
 }
 
@@ -50,8 +55,11 @@ struct index_select_gpu : typed_primitive_gpu_impl<index_select>
         auto index_select_params          = get_default_params<kernel_selector::index_select_params>(arg, 1);
         auto index_select_optional_params = get_default_optional_params<kernel_selector::index_select_optional_params>(arg.get_program());
 
-        index_select_params.inputs.push_back(convert_data_tensor(arg.indices().get_output_layout()));
-        index_select_params.axis = convect_to_index_select_axis(arg.get_axis());
+        if (!arg.get_reverse())
+            index_select_params.inputs.push_back(convert_data_tensor(arg.indices().get_output_layout()));
+
+        index_select_params.axes = convert_to_index_select_axis(arg.get_axes());
+        index_select_params.reverse = arg.get_reverse();
         
         auto& kernel_selector = kernel_selector::index_select_kernel_selector::Instance();
         auto best_kernels = kernel_selector.GetBestKernels(index_select_params, index_select_optional_params);