Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / clDNN / kernel_selector / core / actual_kernels / select / select_kernel_base.cpp
1 /*
2 // Copyright (c) 2018 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16
17 #include "select_kernel_base.h"
18 #include "kernel_selector_utils.h" 
19
20 namespace kernel_selector
21 {
22     
23     bool SelectKernelBase::Validate(const Params& p, const optional_params& o) const
24     {
25         if (p.GetType() != KernelType::SELECT ||
26             o.GetType() != KernelType::SELECT)
27         {
28             return false;
29         }
30
31         const select_params& params = static_cast<const select_params&>(p);
32
33                 if (params.inputs[0].GetDType() != params.inputs[1].GetDType()) 
34                 {
35                         return false;
36                 }
37
38         if (params.inputs.size() != 3)
39         {
40             return false;
41         }
42
43         return true;
44     }
45
46     JitConstants SelectKernelBase::GetJitConstantsCommon(const select_params& params) const
47     {
48         JitConstants jit = MakeBaseParamsJitConstants(params);
49
50         std::string inputs_decls;
51
52         for (size_t i = 0; i < params.inputs.size(); i++)
53         {
54             std::string const_str = "const";
55
56             inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
57         }
58
59         jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
60
61                 std::string destType, absType;
62
63                 // i8, i8, i8
64                 // i8, i8, u8
65                 // u8, u8, i8
66                 // u8, u8, u8
67                 if ((params.inputs[2].GetDType() == Datatype::INT8
68                         || params.inputs[2].GetDType() == Datatype::UINT8)
69                         && (params.inputs[0].GetDType() == Datatype::INT8
70                                 || params.inputs[0].GetDType() == Datatype::UINT8))
71                 {
72                         jit.AddConstant(MakeJitConstant("MASK", "INPUT_2"));
73                 }
74                 else
75                 {
76                         // x, x, f32
77                         // x, x, f16
78                         if (params.inputs[2].GetDType() == Datatype::F32
79                                 || params.inputs[2].GetDType() == Datatype::F16)
80                         {
81                                 absType = "fabs";
82                         }
83                         // f32, f32, i8
84                         // f32, f32, u8
85                         // f16, f16, i8
86                         // f16, f16, u8
87                         else
88                         {
89                                 absType = "abs";
90                         }
91
92                         // f32, f32, x
93                         if (params.inputs[0].GetDType() == Datatype::F32) {
94                                 destType = "int";
95                         }
96                         // f16, f16, x
97                         else if (params.inputs[0].GetDType() == Datatype::F16) {
98                                 destType = "short";
99                         }
100                         // i8, i8, f32
101                         // i8, i8, f16
102                         // u8, u8, f32
103                         // u8, u8, f16
104                         else
105                         {
106                                 destType = "char";
107                         }
108
109                         jit.AddConstant(MakeJitConstant("MASK", "convert_" + destType + "_rtp(" + absType + "(INPUT_2))"));
110                 }
111
112         return jit;
113     }
114
115     JitConstants SelectKernelBase::GetJitConstants(const select_params& params) const
116     {
117         return GetJitConstantsCommon(params);
118     }
119
120     SelectKernelBase::DispatchData SelectKernelBase::SetDefault(const select_params& params) const
121     {
122         DispatchData kd;
123
124         const auto& out = params.output;
125
126         std::vector<size_t> gws;
127         for (const auto& o : out.GetDims())
128         {
129             gws.push_back(o.v);
130         }
131
132         for (size_t i = gws.size(); i < 4; i++)
133         {
134             gws.push_back(1U);
135         }
136
137         kd.gws0 = gws[0];
138         kd.gws1 = gws[1];
139         kd.gws2 = gws[2] * gws[3];
140
141         auto local = GetOptimalLocalWorkGroupSizes( { kd.gws0, kd.gws1, kd.gws2 } );
142         kd.lws0 = local[0];
143         kd.lws1 = local[1];
144         kd.lws2 = local[2];
145
146         return kd;
147     }
148
149     KernelsData SelectKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
150     {
151         if (!Validate(params, options))
152         {
153             return{};
154         }
155
156         KernelData kd = KernelData::Default<select_params>(params);
157         select_params& newParams = *static_cast<select_params*>(kd.params.get());
158
159         auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
160         auto cldnn_jit = GetJitConstants(newParams);
161         std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
162
163         DispatchData runInfo = SetDefault(newParams);
164
165         auto& kernel = kd.kernels[0];
166
167         kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
168         kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
169
170         kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
171         kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false);
172
173         kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;
174
175         return{ kd };
176     }
177 }