Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / lookup_table / lookup_table_kernel_axis.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "lookup_table_kernel_axis.h"
18
19 namespace kernel_selector
20 {
21     ParamsKey LookUpTableKernelAxis::GetSupportedKey() const
22     {
23         ParamsKey k;
24         k.EnableInputDataType(Datatype::F16);
25         k.EnableInputDataType(Datatype::F32);
26         k.EnableInputDataType(Datatype::INT8);
27         k.EnableOutputDataType(Datatype::F32);
28         k.EnableOutputDataType(Datatype::F16);
29         k.EnableOutputDataType(Datatype::INT8);
30         k.EnableLookUpTableIndicesFormat(Datatype::F32);
31         k.EnableInputLayout(DataLayout::bfyx);
32         k.EnableOutputLayout(DataLayout::bfyx);
33         k.EnableLookUpTableAxis(LookUpTableAxis::BATCH);
34         k.EnableLookUpTableAxis(LookUpTableAxis::X);
35         k.EnableLookUpTableAxis(LookUpTableAxis::Y);
36         k.EnableLookUpTableAxis(LookUpTableAxis::FEATURE);
37         k.EnableBatching();
38         return k;
39     }
40
41     KernelsData LookUpTableKernelAxis::GetKernelsData(const Params& params, const optional_params& options) const
42     {
43         if (!Validate(params, options))
44         {
45             return{};
46         }
47
48         const lookup_table_params& orgParams = static_cast<const lookup_table_params&>(params);
49
50         DispatchData runInfo;
51         runInfo.fp16UnitUsed = orgParams.inputs[0].GetDType() == Datatype::F16;
52
53         if (orgParams.lookUpTableAxis == LookUpTableAxis::BATCH) {
54             runInfo.gws0 = orgParams.inputs[0].X().v;
55             runInfo.gws1 = orgParams.inputs[0].Y().v;
56             runInfo.gws2 = orgParams.inputs[0].Feature().v;
57         }
58         else if (orgParams.lookUpTableAxis == LookUpTableAxis::FEATURE) {
59             runInfo.gws0 = orgParams.inputs[0].X().v;
60             runInfo.gws1 = orgParams.inputs[0].Y().v;
61             runInfo.gws2 = orgParams.inputs[0].Batch().v;
62         }
63         else if (orgParams.lookUpTableAxis == LookUpTableAxis::Y) {
64             runInfo.gws0 = orgParams.inputs[0].X().v;
65             runInfo.gws1 = orgParams.inputs[0].Feature().v;
66             runInfo.gws2 = orgParams.inputs[0].Batch().v;
67         }
68         else if (orgParams.lookUpTableAxis == LookUpTableAxis::X) {
69             runInfo.gws0 = orgParams.inputs[0].Y().v;
70             runInfo.gws1 = orgParams.inputs[0].Feature().v;
71             runInfo.gws2 = orgParams.inputs[0].Batch().v;
72         }
73
74         runInfo.lws0 = std::min(std::max(runInfo.gws0, static_cast<size_t>(1)), static_cast<size_t>(32));
75         while (runInfo.gws0 % runInfo.lws0 != 0)
76         {
77             --runInfo.lws0;
78         }
79         runInfo.lws1 = 1;
80         runInfo.lws2 = 1;
81
82         KernelData kd = KernelData::Default<lookup_table_params>(params);
83
84         auto cldnn_jit = GetJitConstants(orgParams);
85         auto entry_point = GetEntryPoint(kernelName, orgParams.layerID, options);
86         auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
87
88         auto& kernel = kd.kernels[0];
89         FillCLKernelData(kernel, runInfo, kernelName, jit, entry_point, "", false, false, 2);
90
91         kd.estimatedTime = FORCE_PRIORITY_9;
92
93         return{ kd };
94     }
95 }