2 // Copyright (c) 2018 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "select_kernel_base.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
23 bool SelectKernelBase::Validate(const Params& p, const optional_params& o) const
25 if (p.GetType() != KernelType::SELECT ||
26 o.GetType() != KernelType::SELECT)
31 const select_params& params = static_cast<const select_params&>(p);
33 if (params.inputs[0].GetDType() != params.inputs[1].GetDType())
38 if (params.inputs.size() != 3)
// Builds the JIT constants for a select kernel: starts from the constants
// returned by MakeBaseParamsJitConstants(params) and adds INPUTS_DECLS and MASK.
// Indexes params.inputs[0] and params.inputs[2] without checking the input
// count, so callers are expected to pass params that already passed Validate().
// NOTE(review): some statements of this function (brace lines, the assignments
// to destType/absType and the return) are not visible in this excerpt.
46 JitConstants SelectKernelBase::GetJitConstantsCommon(const select_params& params) const
48 JitConstants jit = MakeBaseParamsJitConstants(params);
// Text of the kernel's input argument list, one pointer declaration per input.
50 std::string inputs_decls;
52 for (size_t i = 0; i < params.inputs.size(); i++)
54 std::string const_str = "const";
// Appends "const __global <type>* input<i>, " where <type> is toCLType() of the
// input's data type. Every entry, including the last one, ends with ", ".
56 inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
59 jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
// Name fragments spliced into the MASK expression built at the end.
61 std::string destType, absType;
// When inputs[2] and inputs[0] are both INT8/UINT8, MASK is INPUT_2 unchanged.
67 if ((params.inputs[2].GetDType() == Datatype::INT8
68 || params.inputs[2].GetDType() == Datatype::UINT8)
69 && (params.inputs[0].GetDType() == Datatype::INT8
70 || params.inputs[0].GetDType() == Datatype::UINT8))
72 jit.AddConstant(MakeJitConstant("MASK", "INPUT_2"));
// The branches below test inputs[2] for F32/F16 and inputs[0] for F32, then F16.
// NOTE(review): the statements they guard are not visible here; presumably they
// select absType and destType respectively - confirm against the full file.
78 if (params.inputs[2].GetDType() == Datatype::F32
79 || params.inputs[2].GetDType() == Datatype::F16)
93 if (params.inputs[0].GetDType() == Datatype::F32) {
97 else if (params.inputs[0].GetDType() == Datatype::F16) {
// MASK becomes the text convert_<destType>_rtp(<absType>(INPUT_2)).
109 jit.AddConstant(MakeJitConstant("MASK", "convert_" + destType + "_rtp(" + absType + "(INPUT_2))"));
115 JitConstants SelectKernelBase::GetJitConstants(const select_params& params) const
117 return GetJitConstantsCommon(params);
// Derives the dispatch (global/local work size) data for a select kernel from
// the dimensions of params.output.
// NOTE(review): several statements of this function (both loop bodies, the
// declaration of kd, the remaining gws/lws assignments and the return) are not
// visible in this excerpt.
120 SelectKernelBase::DispatchData SelectKernelBase::SetDefault(const select_params& params) const
124 const auto& out = params.output;
126 std::vector<size_t> gws;
// Visits every output dimension; presumably each one contributes an entry to
// gws - loop body not visible, confirm.
127 for (const auto& o : out.GetDims())
// Iterates from the current gws size up to 4; presumably pads gws to four
// entries so that gws[3] below is valid - loop body not visible, confirm.
132 for (size_t i = gws.size(); i < 4; i++)
// The third global size is the product of gws[2] and gws[3].
139 kd.gws2 = gws[2] * gws[3];
// Local work-group sizes are obtained from GetOptimalLocalWorkGroupSizes for
// the three global sizes.
141 auto local = GetOptimalLocalWorkGroupSizes( { kd.gws0, kd.gws1, kd.gws2 } );
// Assembles the kernel description for a select operation: validates the
// request, generates the entry point and JIT text, computes dispatch sizes and
// fills in the first kernel of the resulting KernelData.
// NOTE(review): the statement guarded by the Validate() check and the final
// return are not visible in this excerpt.
149 KernelsData SelectKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
// The rest of the function only applies to params/options accepted by Validate().
151 if (!Validate(params, options))
// newParams refers to the params object owned by kd (presumably a copy created
// by KernelData::Default - confirm), not to the caller's `params`.
156 KernelData kd = KernelData::Default<select_params>(params);
157 select_params& newParams = *static_cast<select_params*>(kd.params.get());
// Entry point name, JIT constants and the resulting JIT text for this kernel.
159 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
160 auto cldnn_jit = GetJitConstants(newParams);
161 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
163 DispatchData runInfo = SetDefault(newParams);
165 auto& kernel = kd.kernels[0];
// Global and local work sizes are taken from the dispatch data computed by SetDefault().
167 kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
168 kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
170 kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
// Passes the input count to GetArgsDesc; the meaning of the two `false` flags
// is defined by GetArgsDesc, which is not shown here.
171 kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false);
// NOTE(review): the constant's name suggests a lowest-priority time estimate - confirm.
173 kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;