Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / one_hot / one_hot_kernel_base.cpp
1 // Copyright (c) 2019 Intel Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15
16 #include "one_hot_kernel_base.h"
17
18 #include "kernel_selector_utils.h"
19
20
21 namespace kernel_selector
22 {
23     JitConstants OneHotKernelBase::GetJitConstants(const one_hot_params& params)
24     {
25         JitConstants jit = MakeBaseParamsJitConstants(params);
26
27         jit.AddConstants({
28             MakeJitConstant("ONE_HOT_AXIS", params.one_hot_axis),
29             MakeJitConstant("ONE_HOT_LIMIT", params.one_hot_limit)
30         });
31
32         return jit;
33     }
34
35     OneHotKernelBase::DispatchData OneHotKernelBase::SetDefault(const one_hot_params& params)
36     {
37         const auto& input = params.inputs[0];
38
39         DispatchData kd;
40
41         kd.fp16UnitUsed = input.GetDType() == Datatype::F16;
42
43         std::vector<size_t> global{ input.Feature().v, input.Y().v, input.X().v };
44         const auto& local = GetOptimalLocalWorkGroupSizes(global);
45
46         kd.gws0 = global[0];
47         kd.gws1 = global[1];
48         kd.gws2 = global[2];
49
50         kd.lws0 = local[0];
51         kd.lws1 = local[1];
52         kd.lws2 = local[2];
53
54         return kd;
55     }
56
57     KernelsData OneHotKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimated_time) const
58     {
59         assert(params.GetType() == KernelType::ONE_HOT);
60
61         const auto& prim_params = static_cast<const one_hot_params&>(params); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
62
63         auto run_info = SetDefault(prim_params);
64         KernelData k_data = KernelData::Default<one_hot_params>(params);
65
66         auto cldnn_jit = GetJitConstants(prim_params);
67         auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, options);
68         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
69
70         auto& kernel = k_data.kernels[0];
71         FillCLKernelData(kernel, run_info, params.engineInfo, kernelName, jit, entry_point);
72         k_data.estimatedTime = estimated_time;
73
74         return{ k_data };
75     }
76 }