Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / lookup_table / lookup_table_kernel_base.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "lookup_table_kernel_base.h"
18
19 namespace kernel_selector
20 {
21     bool LookUpTableKernelBase::Validate(const Params& p, const optional_params& o) const
22     {
23         if (p.GetType() != KernelType::LOOKUP_TABLE ||
24             o.GetType() != KernelType::LOOKUP_TABLE)
25         {
26             return false;
27         }
28
29         return true;
30     }
31
32     JitConstants LookUpTableKernelBase::GetJitConstants(const lookup_table_params& params) const
33     {
34         JitConstants jit = MakeBaseParamsJitConstants(params);
35
36         jit.AddConstants({
37             MakeJitConstant("VAL_NUM", params.numberOfValues),
38             MakeJitConstant(toString(params.lookUpTableAxis) + "_AXIS", 1),
39         });
40
41         return jit;
42     }
43
44     LookUpTableKernelBase::DispatchData LookUpTableKernelBase::SetDefault(const lookup_table_params& params) const
45     {
46         DispatchData kd;
47
48         kd.fp16UnitUsed = params.inputs[0].GetDType() == Datatype::F16;
49
50         // Determine global work sizes.
51         kd.gws0 = params.inputIndices.X().v;
52         kd.gws1 = params.inputIndices.Batch().v;                   // B
53         kd.gws2 = 1;
54
55         kd.lws0 = std::min(std::max(kd.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
56         while (kd.gws0 % kd.lws0 != 0)
57         {
58             --kd.lws0;
59         }
60         kd.lws1 = 1;
61         kd.lws2 = 1;
62
63         return kd;
64     }
65
66     KernelsData LookUpTableKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options, float estimatedTime) const
67     {
68         if (!Validate(params, options))
69         {
70             return{};
71         }
72
73         const lookup_table_params& orgParams = static_cast<const lookup_table_params&>(params);
74
75         DispatchData runInfo = SetDefault(orgParams);
76
77         KernelData kd = KernelData::Default<lookup_table_params>(params);
78
79         auto cldnn_jit = GetJitConstants(orgParams);
80         auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
81         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
82
83         auto& kernel = kd.kernels[0];
84         FillCLKernelData(kernel, runInfo, kernelName, jit, entry_point, "", false, false, 2);
85
86         kd.estimatedTime = estimatedTime;
87
88         return{ kd };
89     }
90 }