1 // Copyright (c) 2018 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "index_select_kernel_base.h"
18 #include "kernel_selector_utils.h"
21 namespace kernel_selector
// Builds the JIT constants for an index_select kernel from `params`.
// Starts from the base-params constants and then adds AXES_NUMBER, REVERSE and
// a pair of constants for every entry of params.axes.
// NOTE(review): this listing seems to be missing lines (braces, the declaration
// of `ss`, the return statement); the comments below describe only the
// statements that are visible.
23 JitConstants IndexSelectKernelBase::GetJitConstants(const index_select_params& params)
25 JitConstants jit = MakeBaseParamsJitConstants(params);
// Number of axes listed in params.axes.
27 jit.AddConstant(MakeJitConstant("AXES_NUMBER", params.axes.size()));
// NOTE(review): any condition guarding this constant is not visible here -
// confirm against the full file.
30 jit.AddConstant(MakeJitConstant("REVERSE", 1));
// One pass per axis in params.axes.
33 for (size_t i = 0; i < params.axes.size(); i++)
// Default constant name; size_value starts at 0.
35 std::string size_name = "REVERSE_AXIS_SIZE";
36 size_t size_value = 0;
// With more than one axis, a per-axis name "REVERSE_" + toString(axis) + "_SIZE"
// is streamed into `ss`.
// NOTE(review): `ss` is not declared in the visible lines and no visible
// statement copies it into size_name - presumably both happen on lines missing
// from this listing; verify.
37 if (params.axes.size() > 1) {
39 ss << "REVERSE_" << toString(params.axes[i]) << "_SIZE";
// Adds a constant named after the axis (toString(params.axes[i])) with an
// empty-string value.
42 jit.AddConstant(MakeJitConstant(toString(params.axes[i]), ""));
// Picks size_value from the matching dimension of the first input
// (params.inputs.at(0)); it keeps its initial 0 if no visible branch matches.
44 if (params.axes[i] == IndexSelectAxis::BATCH)
46 size_value = params.inputs.at(0).Batch().v;
48 else if (params.axes[i] == IndexSelectAxis::X)
50 size_value = params.inputs.at(0).X().v;
52 else if (params.axes[i] == IndexSelectAxis::Y)
54 size_value = params.inputs.at(0).Y().v;
56 else if (params.axes[i] == IndexSelectAxis::FEATURE)
58 size_value = params.inputs.at(0).Feature().v;
// Adds the size constant under size_name (presumably still inside the loop
// body - TODO confirm, the closing braces are not visible).
61 jit.AddConstant(MakeJitConstant(size_name, size_value));
// Computes the default dispatch data for an index_select kernel: sets the fp16
// flag and fills a three-element `global` size vector that depends on the
// selected axis, then derives `local` from it.
// NOTE(review): this listing seems to be missing lines (braces, the declaration
// of `kd`, the condition separating the two axes[0] branch chains below, and
// the code that stores global/local and returns); the comments describe only
// the statements that are visible.
67 IndexSelectKernelBase::DispatchData IndexSelectKernelBase::SetDefault(const index_select_params& params)
69 const auto& output = params.output;
// True when the first input's data type is F16.
72 kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
// Presumably the global work size handed to the kernel launch - TODO confirm.
74 std::vector<size_t> global;
// Single-axis case.
76 if(params.axes.size() == 1) {
// First chain: the middle element is the selected axis' size taken from the
// first input (params.inputs.at(0)).
// NOTE(review): the condition that selects this chain over the second one is
// not visible in this listing.
79 if (params.axes[0] == IndexSelectAxis::BATCH)
81 global = { 1, params.inputs.at(0).Batch().v, output.Feature().v };
83 else if (params.axes[0] == IndexSelectAxis::X)
85 global = { output.Batch().v, params.inputs.at(0).X().v, output.Feature().v };
87 else if (params.axes[0] == IndexSelectAxis::Y)
89 global = { output.Batch().v, params.inputs.at(0).Y().v, output.Feature().v };
91 else if (params.axes[0] == IndexSelectAxis::FEATURE)
93 global = { output.Batch().v, params.inputs.at(0).Feature().v, output.Y().v };
// Second chain: the middle element is the X size of the second input
// (params.inputs.at(1)), named `indices` here.
98 const auto indices = params.inputs.at(1).X().v;
100 if (params.axes[0] == IndexSelectAxis::BATCH)
102 global = { 1, indices, output.Feature().v };
// X and Y share the same global sizes in this chain.
104 else if (params.axes[0] == IndexSelectAxis::X || params.axes[0] == IndexSelectAxis::Y)
106 global = { output.Batch().v, indices, output.Feature().v };
108 else if (params.axes[0] == IndexSelectAxis::FEATURE)
110 global = { output.Batch().v, indices, output.Y().v };
// Sizes taken purely from the output (batch, Y, feature).
// NOTE(review): presumably the branch for more than one axis - the enclosing
// else is not visible; verify.
118 global = { output.Batch().v, output.Y().v, output.Feature().v };
// Local sizes are derived from `global` by GetOptimalLocalWorkGroupSizes.
122 const auto& local = GetOptimalLocalWorkGroupSizes(global);
// Assembles the kernel data for an index_select primitive: computes dispatch
// data and JIT constants, builds the JIT code for `kernelName`, fills the first
// kernel entry and records `estimated_time`.
// NOTE(review): the braces and the return statement are not visible in this
// listing; the comments describe only the statements that are visible.
135 KernelsData IndexSelectKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
// The unchecked downcast below relies on this type check (assert only).
137 assert(params.GetType() == KernelType::INDEX_SELECT);
139 const auto& prim_params = static_cast<const index_select_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
// Dispatch data from SetDefault and a default-constructed KernelData for
// index_select_params.
141 auto run_info = SetDefault(prim_params);
142 KernelData k_data = KernelData::Default<index_select_params>(params);
// JIT constants -> entry point name -> JIT code for this kernel.
144 auto cldnn_jit = GetJitConstants(prim_params);
145 auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
146 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Fills the first kernel slot; the final argument is the number of inputs
// (cast to uint32_t).
148 auto& kernel = k_data.kernels[0];
149 FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point, DEFAULT, false, false, (uint32_t)prim_params.inputs.size());
// Stores the caller-supplied time estimate on the kernel data.
151 k_data.estimatedTime = estimated_time;