namespace
{
- inline kernel_selector::IndexSelectAxis convect_to_index_select_axis(index_select_axis_name axis)
+ inline std::vector<kernel_selector::IndexSelectAxis> convert_to_index_select_axis(std::vector<index_select_axis_name> axes)
{
- switch (axis)
+ std::vector<kernel_selector::IndexSelectAxis> axes_names = {};
+ for (size_t i = 0; i < axes.size(); i++)
{
- case index_select_axis_name::along_b: return kernel_selector::IndexSelectAxis::BATCH;
- case index_select_axis_name::along_f: return kernel_selector::IndexSelectAxis::FEATURE;
- case index_select_axis_name::along_x: return kernel_selector::IndexSelectAxis::X;
- case index_select_axis_name::along_y: return kernel_selector::IndexSelectAxis::Y;
- default:
- return kernel_selector::IndexSelectAxis::BATCH;
+ switch (axes[i])
+ {
+ case index_select_axis_name::along_b: axes_names.push_back(kernel_selector::IndexSelectAxis::BATCH); break;
+ case index_select_axis_name::along_f: axes_names.push_back(kernel_selector::IndexSelectAxis::FEATURE); break;
+ case index_select_axis_name::along_x: axes_names.push_back(kernel_selector::IndexSelectAxis::X); break;
+ case index_select_axis_name::along_y: axes_names.push_back(kernel_selector::IndexSelectAxis::Y); break;
+ default:
+ axes_names.push_back(kernel_selector::IndexSelectAxis::BATCH); break;
+ }
}
+ return axes_names;
}
}
auto index_select_params = get_default_params<kernel_selector::index_select_params>(arg, 1);
auto index_select_optional_params = get_default_optional_params<kernel_selector::index_select_optional_params>(arg.get_program());
- index_select_params.inputs.push_back(convert_data_tensor(arg.indices().get_output_layout()));
- index_select_params.axis = convect_to_index_select_axis(arg.get_axis());
+ if (!arg.get_reverse())
+ index_select_params.inputs.push_back(convert_data_tensor(arg.indices().get_output_layout()));
+
+ index_select_params.axes = convert_to_index_select_axis(arg.get_axes());
+ index_select_params.reverse = arg.get_reverse();
auto& kernel_selector = kernel_selector::index_select_kernel_selector::Instance();
auto best_kernels = kernel_selector.GetBestKernels(index_select_params, index_select_optional_params);