2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "lookup_table_kernel_axis.h"
19 namespace kernel_selector
// Describes the parameter combinations this kernel declares as supported by
// enabling them on the key `k`:
//   - input and output data types: F16, F32, INT8
//   - lookup-table indices format: F32
//   - input and output layout: bfyx
//   - lookup axis: BATCH, X, Y, FEATURE
21 ParamsKey LookUpTableKernelAxis::GetSupportedKey() const
// Input data types.
24 k.EnableInputDataType(Datatype::F16);
25 k.EnableInputDataType(Datatype::F32);
26 k.EnableInputDataType(Datatype::INT8);
// Output data types (same set as the inputs).
27 k.EnableOutputDataType(Datatype::F32);
28 k.EnableOutputDataType(Datatype::F16);
29 k.EnableOutputDataType(Datatype::INT8);
// Only F32 is enabled for the lookup-table indices.
30 k.EnableLookUpTableIndicesFormat(Datatype::F32);
// Only the bfyx layout is enabled, for both input and output.
31 k.EnableInputLayout(DataLayout::bfyx);
32 k.EnableOutputLayout(DataLayout::bfyx);
// All four axes are enabled for the lookup.
33 k.EnableLookUpTableAxis(LookUpTableAxis::BATCH);
34 k.EnableLookUpTableAxis(LookUpTableAxis::X);
35 k.EnableLookUpTableAxis(LookUpTableAxis::Y);
36 k.EnableLookUpTableAxis(LookUpTableAxis::FEATURE);
// Prepares the kernel data for a lookup-table operation along one axis:
// chooses the dispatch sizes from the first input's dimensions, builds the
// JIT code and fills in the kernel description.
//
// params  - expected to be a lookup_table_params instance (it is downcast below).
// options - forwarded to Validate() and GetEntryPoint().
41 KernelsData LookUpTableKernelAxis::GetKernelsData(const Params& params, const optional_params& options) const
// Validation comes first; the static_cast below is unchecked.
// NOTE(review): presumably Validate() guarantees the params type - confirm.
43 if (!Validate(params, options))
48 const lookup_table_params& orgParams = static_cast<const lookup_table_params&>(params);
// Flag FP16 usage when the first input holds F16 data.
51 runInfo.fp16UnitUsed = orgParams.inputs[0].GetDType() == Datatype::F16;
// The global work size covers the three dimensions of the first input that
// are NOT the lookup axis; the lookup axis itself gets no dispatch dimension.
// Lookup along BATCH: dispatch over X, Y, Feature.
53 if (orgParams.lookUpTableAxis == LookUpTableAxis::BATCH) {
54 runInfo.gws0 = orgParams.inputs[0].X().v;
55 runInfo.gws1 = orgParams.inputs[0].Y().v;
56 runInfo.gws2 = orgParams.inputs[0].Feature().v;
// Lookup along FEATURE: dispatch over X, Y, Batch.
58 else if (orgParams.lookUpTableAxis == LookUpTableAxis::FEATURE) {
59 runInfo.gws0 = orgParams.inputs[0].X().v;
60 runInfo.gws1 = orgParams.inputs[0].Y().v;
61 runInfo.gws2 = orgParams.inputs[0].Batch().v;
// Lookup along Y: dispatch over X, Feature, Batch.
63 else if (orgParams.lookUpTableAxis == LookUpTableAxis::Y) {
64 runInfo.gws0 = orgParams.inputs[0].X().v;
65 runInfo.gws1 = orgParams.inputs[0].Feature().v;
66 runInfo.gws2 = orgParams.inputs[0].Batch().v;
// Lookup along X: dispatch over Y, Feature, Batch.
68 else if (orgParams.lookUpTableAxis == LookUpTableAxis::X) {
69 runInfo.gws0 = orgParams.inputs[0].Y().v;
70 runInfo.gws1 = orgParams.inputs[0].Feature().v;
71 runInfo.gws2 = orgParams.inputs[0].Batch().v;
// Initial local size for dimension 0: gws0 clamped to the range [1, 32].
74 runInfo.lws0 = std::min(std::max(runInfo.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
// Keep looping for as long as lws0 does not evenly divide gws0.
75 while (runInfo.gws0 % runInfo.lws0 != 0)
// Start from the default kernel data for lookup_table_params.
82 KernelData kd = KernelData::Default<lookup_table_params>(params);
// Generate the JIT constants, the entry-point name and the final JIT code.
84 auto cldnn_jit = GetJitConstants(orgParams);
85 auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
86 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
// Fill the first kernel with the dispatch info and generated code.
// NOTE(review): the trailing 2 is presumably the number of kernel inputs
// (data + indices) - confirm against FillCLKernelData's signature.
88 auto& kernel = kd.kernels[0];
89 FillCLKernelData(kernel, runInfo, kernelName, jit, entry_point, "", false, false, 2);
// Estimated time is fixed to the FORCE_PRIORITY_9 constant.
91 kd.estimatedTime = FORCE_PRIORITY_9;