1 // Copyright (c) 2019 Intel Corporation
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
16 #include "one_hot_kernel_base.h"
18 #include "kernel_selector_utils.h"
21 namespace kernel_selector
// Builds the JIT constants for the one-hot kernel.
// Starts from the constants MakeBaseParamsJitConstants(params) produces, then
// defines ONE_HOT_AXIS from params.one_hot_axis and ONE_HOT_LIMIT from
// params.one_hot_limit.
// NOTE(review): ONE_HOT_LIMIT presumably carries the one-hot depth (number of
// classes) — confirm against one_hot_params and the OpenCL kernel source.
// NOTE(review): this listing is missing lines between the ones shown — the
// braces, the call that receives the MakeJitConstant list below, and the
// return statement for the declared JitConstants result are not visible;
// verify against the full file before relying on this fragment.
23 JitConstants OneHotKernelBase::GetJitConstants(const one_hot_params& params)
// Start from the constants built from the common parameters.
25 JitConstants jit = MakeBaseParamsJitConstants(params);
// One-hot-specific constants, taken directly from the primitive's params.
28 MakeJitConstant("ONE_HOT_AXIS", params.one_hot_axis),
29 MakeJitConstant("ONE_HOT_LIMIT", params.one_hot_limit)
// Computes the default DispatchData for the one-hot kernel from the first
// input tensor, params.inputs[0].
// NOTE(review): this listing is missing lines between the ones shown — the
// declaration of `kd`, the code that stores `global`/`local` into it, and the
// return statement are not visible; verify against the full file.
35 OneHotKernelBase::DispatchData OneHotKernelBase::SetDefault(const one_hot_params& params)
37 const auto& input = params.inputs[0];
// Record whether the input data type is F16.
41 kd.fp16UnitUsed = input.GetDType() == Datatype::F16;
// Global work sizes come from the input's Feature, Y and X dimensions (`.v`,
// presumably each dimension's extent). Batch is not part of this range.
// NOTE(review): presumably the kernel iterates over batch internally — confirm
// in the OpenCL kernel source.
43 std::vector<size_t> global{ input.Feature().v, input.Y().v, input.X().v };
// Local work-group sizes chosen by the shared helper for this global range.
44 const auto& local = GetOptimalLocalWorkGroupSizes(global);
57 KernelsData OneHotKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
59 assert(params.GetType() == KernelType::ONE_HOT);
61 const auto& prim_params = static_cast<const one_hot_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
63 auto run_info = SetDefault(prim_params);
64 KernelData k_data = KernelData::Default<one_hot_params>(params);
66 auto cldnn_jit = GetJitConstants(prim_params);
67 auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
68 auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
70 auto& kernel = k_data.kernels[0];
71 FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point);
72 k_data.estimatedTime = estimated_time;