2 // Copyright (c) 2016-2019 Intel Corporation
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 #include "eltwise_kernel_base.h"
18 #include "kernel_selector_utils.h"
20 namespace kernel_selector
// Maps an eltwise operation mode to the number of operands it expects.
// Validate() compares the returned value against the size of each operation's
// operand list.
// NOTE(review): the switch header, the return statements and any default branch
// are not visible in this view - confirm the value returned for each group.
22 static uint32_t GetNumberOfInputs(EltwiseMode m)
// Modes whose generated expression combines operand 0 and operand 1
// (presumably the two-operand group; its return value is not visible here).
26 case EltwiseMode::ADD:
27 case EltwiseMode::SUB:
28 case EltwiseMode::MUL:
29 case EltwiseMode::DIV:
30 case EltwiseMode::MIN:
31 case EltwiseMode::MAX:
32 case EltwiseMode::POW:
33 case EltwiseMode::MODULU:
40 case EltwiseMode::LOGIC_AND:
41 case EltwiseMode::LOGIC_OR:
42 case EltwiseMode::LOGIC_XOR:
43 case EltwiseMode::SQUARED_DIFF:
// Modes whose generated expression reads operand 0 only
// (presumably the single-operand group; its return value is not visible here).
45 case EltwiseMode::SQRT:
46 case EltwiseMode::RSQRT:
47 case EltwiseMode::ASSIGN:
// Builds the ParamsKey for these eltwise parameters: starts from the key
// produced by base_params and enables the eltwise-specific capabilities that
// this parameter set uses.
54 ParamsKey eltwise_params::GetParamsKey() const
56 ParamsKey k = base_params::GetParamsKey();
// Int8 quantization is flagged in the key only when it is requested.
57 if (int8_quantization)
59 k.EnableInt8Quantization();
// Output calibration is flagged in the key only when it is requested.
62 if (output_calibration)
64 k.EnableOutputCalibration();
// NOTE(review): the conditions guarding the next two calls are not visible in
// this view; presumably stride is enabled when `stride` is non-empty and
// broadcast when `broadcast` is set - confirm.
69 k.EnableEltwiseStride();
74 k.EnableEltwiseBroadcast();
// Checks that `p` / `o` describe an eltwise kernel and that the operation list
// is well formed: at least one input, at least one operation, the right number
// of operands per operation, and in-range buffer indices.
// NOTE(review): the bodies of the failing branches (presumably `return false`)
// and the final return are not visible in this view.
80 bool EltwiseKernelBase::Validate(const Params& p, const optional_params& o) const
// Both the params and the optional params must be of the ELTWISE kernel type.
82 if (p.GetType() != KernelType::ELTWISE ||
83 o.GetType() != KernelType::ELTWISE)
// Downcast relies on the kernel-type check above having rejected other types.
88 const eltwise_params& params = static_cast<const eltwise_params&>(p);
// At least one input tensor is required.
90 if (params.inputs.size() == 0)
95 auto& operations = params.operations;
// At least one operation is required.
97 if (operations.size() == 0)
102 for (size_t op_num = 0; op_num < operations.size(); op_num++)
104 const auto& ew = operations[op_num];
// The operand count must match what the operation mode expects.
106 if (ew.inputs.size() != GetNumberOfInputs(ew.mode))
111 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
113 const auto& input = ew.inputs[input_idx];
// An INPUT_BUFFER operand must refer to an existing entry of params.inputs.
// Only this operand mode is range-checked in the visible code.
114 if (input.mode == EltwiseInputMode::INPUT_BUFFER &&
115 input.index >= params.inputs.size())
// Builds the JIT constants shared by the eltwise kernels. The visible steps are:
//   1. base-params constants plus the layout / quantization / broadcast switches;
//   2. O_QF (and CALIBRATION_TERM) for the int8-quantized path;
//   3. INPUTS_DECLS (kernel argument list) and VLOAD_DECLS (8-wide loads),
//      plus per-input stride constants;
//   4. one INPUT_<op>_<idx> macro per operand of every operation;
//   5. one OPERATION<op> macro per operation, each defining tmp<op>;
//   6. DO_ELTWISE, which chains the operations, writes back updated inputs and
//      assigns the last temporary to `res`.
// `useVload8` presumably selects between the 8-wide vload path and the scalar
// path; the conditions that read it are not visible in this view.
// NOTE(review): several lines (braces, `else` / `switch` headers, the return)
// are missing from this view, so branch structure is described only where the
// visible lines establish it.
125 JitConstants EltwiseKernelBase::GetJitConstantsCommon(const eltwise_params& params, bool useVload8) const
// Start from the constants derived from the base params.
127 JitConstants jit = MakeBaseParamsJitConstants(params);
// Feature switches forwarded to the kernel source as macros.
130 MakeJitConstant("ELTWISE_LAYOUT_BASED", params.layoutBased),
131 MakeJitConstant("QUANTIZATION_TERM", params.int8_quantization),
132 MakeJitConstant("ELTWISE_BROADCAST", params.broadcast),
// Quantized path: with output calibration, O_QF is taken from the first
// calibration factor and CALIBRATION_TERM is defined as well.
135 if (params.int8_quantization)
137 if (params.output_calibration)
139 jit.AddConstant(MakeJitConstant("CALIBRATION_TERM", params.output_calibration));
140 jit.AddConstant(MakeJitConstant("O_QF", params.output_calibration_factors[0]));
// NOTE(review): presumably the else-branch of the calibration check (the `else`
// is not visible in this view): O_QF comes from the single quantization factor.
144 jit.AddConstants({ MakeJitConstant("O_QF", params.output_quantization_factor) });
// Accumulate the kernel argument declarations and the vector-load declarations
// while walking the inputs.
147 std::string inputs_decls, vload_decls;
148 auto& updateInputs = params.updateInputIds;
150 for (size_t i = 0; i < params.inputs.size(); i++)
152 //const should be added only to inputs which will not be updated
153 std::string const_str = "const";
// Look for input i among the inputs that get written back.
// NOTE(review): the body of the match below is not visible; presumably it
// clears const_str - confirm.
154 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
156 if (updateInputs[update_input_idx].inputId == i)
// Argument declaration: "<const> __global <type>* input<i>, ".
163 inputs_decls += const_str + " __global " + toCLType(params.inputs[i].GetDType()) + "* input" + std::to_string(i) + ", ";
// Per-input stride constants.
// NOTE(review): only non-emptiness of params.stride is checked; indexing by i
// assumes one stride entry per input - confirm against the callers.
164 if (!params.stride.empty())
166 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_X", params.stride[i].x));
167 jit.AddConstant(MakeJitConstant("INPUT" + std::to_string(i) + "_STRIDE_Y", params.stride[i].y));
// 8-wide value in<i>: a single-element input is splatted from element 0,
// anything else is loaded with vload8 at global_id.
// NOTE(review): the guard around this section, the `else` between the two
// alternatives and the closing of the generated expression are not visible.
171 vload_decls += "\\\n\tconst " + toCLType(params.inputs[i].GetDType()) + "8 in" + std::to_string(i);
172 if (params.inputs[i].PhysicalSize() == 1) //Scalar case
173 vload_decls += " = (" + toCLType(params.inputs[i].GetDType()) + "8)(input" + std::to_string(i) + "[0]";
175 vload_decls += " = vload8(global_id, input" + std::to_string(i);
// Publish the accumulated declarations and the "no pitch, same dims" switch.
180 jit.AddConstant(MakeJitConstant("INPUTS_DECLS", inputs_decls));
181 jit.AddConstant(MakeJitConstant("ELTWISE_NO_PITCH_SAME_DIMS", CheckInputsOutputNoPitchSameDims(params)));
// NOTE(review): the condition guarding VLOAD_DECLS, if any, is not visible.
184 jit.AddConstant(MakeJitConstant("VLOAD_DECLS", vload_decls));
// do_eltwise collects the statements that make up the DO_ELTWISE macro body.
186 std::string do_eltwise;
188 auto& operations = params.operations;
189 auto& coefficients = params.coefficients;
191 for (size_t op_num = 0; op_num < operations.size(); op_num++)
193 const std::string op_num_str = std::to_string(op_num);
194 const auto& ew = operations[op_num];
// Define one INPUT_<op>_<idx> macro per operand, expanding to the expression
// that reads that operand.
196 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
198 const auto& input = ew.inputs[input_idx];
199 const std::string name = "INPUT_" + op_num_str + "_" + std::to_string(input_idx);
// Literal scalar operand.
202 case EltwiseInputMode::SCALAR:
203 jit.AddConstant(MakeJitConstant(name, input.scalar));
// Kernel input buffer operand. Two alternatives follow: the pre-loaded
// in<index> value and an indexed read through GET_INDEX.
// NOTE(review): the condition choosing between them is not visible in this
// view (presumably useVload8).
205 case EltwiseInputMode::INPUT_BUFFER:
207 jit.AddConstant(MakeJitConstant(name, "in" + std::to_string(input.index)));
209 jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[GET_INDEX(INPUT, " + std::to_string(input.index) + ")]"));
// Operand read back from the output buffer through the GET_INDEX macro.
211 case EltwiseInputMode::OUTPUT_BUFFER:
212 jit.AddConstant(MakeJitConstant(name, "output[GET_INDEX(OUTPUT, )]"));
// Input buffer read at an index held in an earlier temporary (tmp<tmpIndex>).
214 case EltwiseInputMode::UNORDERED_ACCESS_INPUT_BUFFER:
215 jit.AddConstant(MakeJitConstant(name, "input" + std::to_string(input.index) + "[(size_t)tmp" + std::to_string(input.tmpIndex) + "]"));
// Operand is the result of an earlier operation (tmp<tmpIndex>).
217 case EltwiseInputMode::INTERMEDIATE_RESULTS_INDEX:
218 jit.AddConstant(MakeJitConstant(name, "tmp" + std::to_string(input.tmpIndex)));
// Choose how this operation's temporary tmp<op> is declared and which cast is
// prepended to its operands.
225 std::string input0_str, input1_str, cast_type, output_cast, op;
// 8-wide vector temporary.
// NOTE(review): the guarding condition is not visible (presumably useVload8).
229 cast_type = "(MAKE_VECTOR_TYPE(UNIT_TYPE, 8))";
230 op = "const MAKE_VECTOR_TYPE(UNIT_TYPE, 8) tmp" + op_num_str + " = ";
// Quantized path: the temporary is declared as int.
232 else if(params.int8_quantization)
235 op = "const int tmp" + op_num_str + " = ";
// Remaining case: the temporary is declared as UNIT_TYPE.
239 cast_type = "(UNIT_TYPE)";
240 op = "const UNIT_TYPE tmp" + op_num_str + " = ";
// INT8 output without quantization: comparison / logic results are cast to
// char and operands are cast to an input's data type.
// NOTE(review): params.inputs is indexed by the operation number here, not by
// an operand's buffer index; confirm operations.size() <= inputs.size() holds.
243 if (params.output.GetDType() == Datatype::INT8 && !params.int8_quantization) {
244 output_cast = "(char)";
245 cast_type = "(" + toCLType(params.inputs[op_num].GetDType()) + ")";
// Operand expressions: the cast followed by the operand macro name.
248 input0_str = cast_type + "INPUT_" + op_num_str + "_0";
249 input1_str = cast_type + "INPUT_" + op_num_str + "_1";
// ADD only: an INPUT_BUFFER operand whose index has a coefficients entry is
// pre-multiplied by that coefficient.
251 if (ew.mode == EltwiseMode::ADD)
253 std::vector<std::string> coeff_strings(ew.inputs.size(), "");
254 for (size_t input_idx = 0; input_idx < ew.inputs.size(); input_idx++)
256 const auto& input = ew.inputs[input_idx];
257 if (input.mode == EltwiseInputMode::INPUT_BUFFER && input.index < coefficients.size())
259 const float c = coefficients[input.index];
// NOTE(review): any condition on `c` before the next line is not visible here.
261 coeff_strings[input_idx] = cast_type + "(" + std::to_string(c) + ")*";
// Indexing entries 0 and 1 relies on ADD operations having two operands, which
// Validate() checks through GetNumberOfInputs().
265 input0_str = coeff_strings[0] + input0_str;
266 input1_str = coeff_strings[1] + input1_str;
// Append the expression text for the operation mode to the tmp<op> declaration.
272 case EltwiseMode::ADD: op += input0_str + " + " + input1_str; break;
273 case EltwiseMode::SUB: op += input0_str + " - " + input1_str; break;
274 case EltwiseMode::MUL: op += input0_str + " * " + input1_str; break;
275 case EltwiseMode::DIV: op += input0_str + " / " + input1_str; break;
// MODULU / MIN / MAX depend on operand types: integer operands use `%` or the
// plain min / max name, anything else uses the "f"-prefixed name (fmod, fmin,
// fmax) with convert_float applied to an integer operand.
276 case EltwiseMode::MODULU:
277 case EltwiseMode::MIN:
278 case EltwiseMode::MAX:
280 auto mode = (ew.mode == EltwiseMode::MODULU ? "mod" : (ew.mode == EltwiseMode::MIN ? "min" : "max" ));
// NOTE(review): the types come from params.inputs[0] and [1] regardless of
// which buffers this operation's operands reference, and [1] is read even if
// only one input exists - confirm this is intended.
281 auto input_0_type = params.inputs[0].GetDType();
282 auto input_1_type = params.inputs[1].GetDType();
285 if (input_0_type == kernel_selector::Datatype::INT8 ||
286 input_0_type == kernel_selector::Datatype::INT32 ||
287 input_0_type == kernel_selector::Datatype::INT64)
289 // input_0 == int && input_1 == int
290 if (input_1_type == kernel_selector::Datatype::INT8 ||
291 input_1_type == kernel_selector::Datatype::INT32 ||
292 input_1_type == kernel_selector::Datatype::INT64)
294 if (ew.mode == EltwiseMode::MODULU)
295 op += input0_str + " % " + input1_str;
297 op += cast_type + mode + "(" + input0_str + ", " + input1_str + ")";
299 // input_0 == int && input_1 != int
302 op += cast_type + "f" + mode + "(convert_float(" + input0_str + "), " + input1_str + ")";
305 // input_0 != int && input_1 == int
306 else if ( input_1_type == kernel_selector::Datatype::INT8 ||
307 input_1_type == kernel_selector::Datatype::INT32 ||
308 input_1_type == kernel_selector::Datatype::INT64)
310 op += cast_type + "f" + mode + "(" + input0_str + ", convert_float(" + input1_str + "))";
312 // input_0 != int && input_1 != int
315 op += cast_type + "f" + mode + "(" + input0_str + ", " + input1_str + ")";
// Math built-ins; the cast is prepended to the emitted expression text.
318 case EltwiseMode::POW: op += cast_type + "pow(" + input0_str + ", " + input1_str + ")"; break;
319 case EltwiseMode::SQRT: op += cast_type + "sqrt(" + input0_str + ")"; break;
// In the emitted text the cast binds to the literal 1, not to the quotient.
320 case EltwiseMode::RSQRT: op += cast_type + "1/sqrt(" + input0_str + ")"; break;
321 case EltwiseMode::SQUARED_DIFF: op += cast_type + "((" + input0_str + " - " + input1_str + ")"
322 " * (" + input0_str + " - " + input1_str + "))"; break;
// Comparison and logic modes: output_cast is empty unless the output is INT8
// without quantization, in which case it is "(char)".
323 case EltwiseMode::EQ: op += output_cast + "(" + input0_str + " == " + input1_str + ")"; break;
324 case EltwiseMode::NE: op += output_cast + "(" + input0_str + " != " + input1_str + ")"; break;
325 case EltwiseMode::LT: op += output_cast + "(" + input0_str + " < " + input1_str + ")"; break;
326 case EltwiseMode::LE: op += output_cast + "(" + input0_str + " <= " + input1_str + ")"; break;
327 case EltwiseMode::GT: op += output_cast + "(" + input0_str + " > " + input1_str + ")"; break;
328 case EltwiseMode::GE: op += output_cast + "(" + input0_str + " >= " + input1_str + ")"; break;
329 case EltwiseMode::LOGIC_AND: op += output_cast + "(" + input0_str + " && " + input1_str + ")"; break;
330 case EltwiseMode::LOGIC_OR: op += output_cast + "(" + input0_str + " || " + input1_str + ")"; break;
// XOR is emitted as an inequality of the two negated operands.
331 case EltwiseMode::LOGIC_XOR: op += output_cast + "(!" + input0_str + " != !" + input1_str + ")"; break;
332 case EltwiseMode::ASSIGN: op += input0_str; break;
// Publish the statement as OPERATION<op> and append it to the DO_ELTWISE body.
337 std::string opname = "OPERATION" + op_num_str;
338 jit.AddConstant(MakeJitConstant(opname, op));
339 do_eltwise += "\\\n\t" + opname + ";";
// Write-back: each updated input receives the temporary named by tmpId.
// NOTE(review): presumably this runs after the operations loop - the loop's
// closing brace is not visible in this view.
342 for (size_t update_input_idx = 0; update_input_idx < updateInputs.size(); update_input_idx++)
343 do_eltwise += "\\\n\tinput" + std::to_string(updateInputs[update_input_idx].inputId) +
344 "[GET_INDEX(INPUT, " + std::to_string(updateInputs[update_input_idx].inputId) +
345 ")] = tmp" + std::to_string(updateInputs[update_input_idx].tmpId) + ";";
// The kernel result is the temporary of the last operation; size() - 1 relies
// on Validate() rejecting an empty operation list.
347 do_eltwise += "\\\n\tres = tmp" + std::to_string(operations.size() - 1) + ";";
349 jit.AddConstant(MakeJitConstant("DO_ELTWISE", do_eltwise));
// Same condition SetDefault() uses to pick tensor-friendly work groups.
351 if (params.layoutBased || params.int8_quantization || params.broadcast)
353 jit.Merge(GetTensorFriendlyWorkGroupsJit(params.output));
// Tell the kernel that strided input access is in use.
356 if (!params.stride.empty())
358 jit.AddConstant(MakeJitConstant("INPUT_STRIDED", 1));
// Returns the JIT constants for `params` by delegating to
// GetJitConstantsCommon() with useVload8 set to false.
364 JitConstants EltwiseKernelBase::GetJitConstants(const eltwise_params& params) const
366 return GetJitConstantsCommon(params, false);
// Computes the dispatch data (global and local work sizes) for `params`.
// Three ways of deriving the global size are visible: tensor-friendly work
// groups, a flat size taken from input 0, and a size built from the output
// dimensions.
// NOTE(review): the declaration of `kd`, most assignments to it, the loop
// bodies and the return are not visible in this view.
369 EltwiseKernelBase::DispatchData EltwiseKernelBase::SetDefault(const eltwise_params& params) const
// Same condition under which GetJitConstantsCommon() merges the
// tensor-friendly work-group constants.
373 if (params.layoutBased || params.int8_quantization || params.broadcast)
375 auto global = GetTensorFriendlyWorkGroups(params.output);
// With strides, gws0 is divided by the x and then the y stride of the first
// stride entry.
379 if (!params.stride.empty())
381 kd.gws0 /= params.stride[0].x;
382 kd.gws0 /= params.stride[0].y;
// Inputs and output without pitch and with the same dims: gws0 is the logical
// size of input 0.
385 else if (CheckInputsOutputNoPitchSameDims(params))
387 kd.gws0 = params.inputs[0].LogicalSize();
// Remaining case: build the global size from the output dimensions.
393 const auto& out = params.output;
395 std::vector<size_t> gws;
// NOTE(review): loop body not visible; presumably appends a value per dim.
396 for (const auto& o : out.GetDims())
// Extends gws up to four entries (the value appended is not visible here), so
// that indices 2 and 3 below exist.
401 for (size_t i = gws.size(); i < 4; i++)
// The third global dimension folds entries 2 and 3 together.
408 kd.gws2 = gws[2] * gws[3];
// Local sizes are derived from the chosen global sizes.
411 auto local = GetOptimalLocalWorkGroupSizes( { kd.gws0, kd.gws1, kd.gws2 } );
// Builds the kernel data for an eltwise kernel: validates the request, creates
// a default KernelData for eltwise_params, generates the JIT source and fills
// in work-group sizes, kernel string and argument descriptors for kernel 0.
// NOTE(review): the body of the validation failure branch and the return
// statement are not visible in this view.
419 KernelsData EltwiseKernelBase::GetCommonKernelsData(const Params& params, const optional_params& options) const
// Reject requests that Validate() does not accept.
421 if (!Validate(params, options))
// newParams refers to the eltwise_params object held by kd.
426 KernelData kd = KernelData::Default<eltwise_params>(params);
427 eltwise_params& newParams = *static_cast<eltwise_params*>(kd.params.get());
// Generate the entry point name and the JIT text for this kernel.
429 auto entry_point = GetEntryPoint(kernelName, newParams.layerID, options);
430 auto cldnn_jit = GetJitConstants(newParams);
431 std::string jit = CreateJit(kernelName, cldnn_jit, entry_point);
433 DispatchData runInfo = SetDefault(newParams);
435 auto& kernel = kd.kernels[0];
// Copy the computed global and local work sizes into the kernel description.
437 kernel.workGroups.global = { runInfo.gws0, runInfo.gws1, runInfo.gws2 };
438 kernel.workGroups.local = { runInfo.lws0, runInfo.lws1, runInfo.lws2 };
440 kernel.kernelString = GetKernelString(kernelName, jit, entry_point, params.engineInfo, DEFAULT);
// One argument per input, plus the quantization / calibration flags.
// NOTE(review): the meaning of the two `false` arguments is defined by
// GetArgsDesc, which is not visible here.
441 kernel.arguments = GetArgsDesc((uint32_t)newParams.inputs.size(), false, false, newParams.int8_quantization, newParams.output_calibration);
// NOTE(review): judging by the constant's name this marks the kernel as a
// last-resort choice - confirm against its definition.
443 kd.estimatedTime = DONT_USE_IF_HAVE_SOMETHING_ELSE;